Okay, I am going to have a look at coding a simple Generative Adversarial Network (GAN hereafter).

They are generative in that they produce new data instances. The classification models we worked with earlier are discriminative models.

These networks are designed to train a model to produce some form of output that is indistinguishable from the training data. But there is a huge difference in how they are constructed compared with the classification models we saw in the last few posts.

The "Adversarial" in the name refers to the fact that two neural networks compete with each other in a zero-sum game. The Generator network attempts to create/output data that fools the other network into believing it is real data (i.e. from the training set). The Discriminator network attempts to distinguish the fake data produced by the Generator from the real data in the training set. The discriminator’s output is used to calculate the losses that update both models.
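For reference, the zero-sum game described above is usually written (following the original GAN paper by Goodfellow et al., 2014) as a minimax problem over a value function V(D, G). Here x is drawn from the real data, z is random latent noise, D(·) is the discriminator's estimated probability that its input is real, and G(z) is a generated sample:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to push V up by scoring real samples near 1 and fakes near 0. Assuming binary cross-entropy is the loss (the sigmoid output and 1/0 labels used in this post suggest it), that is the same as minimizing BCE against "real" labels on real data plus BCE against "fake" labels on generated data.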

The part I find hard to understand is that the generator initially has no idea what it is going to have to produce. It takes as its input a random sample drawn from a latent space (here, 2-D normally distributed noise) and slowly, somehow, learns to produce what the discriminator sees in the training data.
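To make "a random sample from the latent space" a little more concrete, here is a minimal sketch. It draws a batch of 2-D standard-normal noise vectors (the same kind of input `train_genr()` uses later) and pushes them through an untrained network with the same layer sizes as the post's `Generator`. The variable names are my own. Before any training the output is just a random-looking cloud of (x, y) points; training is what bends this mapping toward the sine curve.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so the sketch is repeatable

batch_sz = 32
latent_dim = 2  # the post samples 2-D noise

# untrained network with the same layer sizes as the post's Generator
gen = nn.Sequential(
    nn.Linear(latent_dim, 16),
    nn.ReLU(),
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
)

# one noise vector per sample, drawn from a standard normal distribution
noise = torch.randn(batch_sz, latent_dim)
with torch.no_grad():
    fake_points = gen(noise)  # shape (32, 2): columns are "x" and "y"

print(noise.shape, fake_points.shape)
```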

As a simple start, I am going to train a GAN to produce a sine curve. Hopefully, by the time I am done, things will become somewhat clearer.

Some Setup Code

This is the setup code for the module. It may change, and it may get additions, but I need somewhere to start.

# gan_sine.py
# Ver 0.1.0: 2024.03.24, rek, get started figuring this out
#  - train GAN to generate sine curve
#     use class to define and instantiate model (best practice?)
#     a bit of code copied from multi-cat_2.py, eg setup, model

import math, time
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn  as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as trf

# set seed for reproducibility
torch.manual_seed(73)

# will need this later
device = "cuda" if torch.cuda.is_available() else "cpu"

# some general parameters
trn_len = 1024
batch_sz = 32

plt_trn = True

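# model from/to file
ld_model = False
sv_model = False
tst_model = False
sv_dir = Path("./sav")
# folder for the per-epoch plots written by chk_perf() (folder name assumed)
img_dir = Path("./img")
img_dir.mkdir(parents=True, exist_ok=True)

# some model parameters, hyperparameters?
max_ep = 512    # maximum number of epochs for training
lr = 0.001      # learning rate
do_p = 0.16     # Dropout probability
e_stop = 9      # epochs without improvement before early stop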

Prepare Training Data

The training data will be composed of pairs (x, y). x is a value in the interval [0, 2π). y is the sine of x. Let’s generate the dataset and have a look at it.

# create training data
ds_train = torch.zeros(trn_len, 2)
ds_train[:, 0] = 2 * math.pi * torch.rand(trn_len)
ds_train[:, 1] = torch.sin(ds_train[:, 0])
lbl_train = torch.ones(trn_len)

# Let's have a look
if plt_trn:
  fig = plt.figure(figsize=(8, 6))
  plt.plot(ds_train[:, 0], ds_train[:, 1], ".", c="r")
  plt.xlabel("Values of x", fontsize=12)
  plt.ylabel("Sine of x", fontsize=12)
  plt.title("A sine curve", fontsize=18)
  plt.show()

And here’s what the training data looks like. Yup, looks like a sine curve to me.

plot showing values in the training set

Dataloader and Data Labels

We will be needing a dataloader for the training data. And we are going to need data labels for the training samples and the fake samples.

From the perspective of the discriminator, pretty clearly all the samples from our training set must be real. And all the ones from the generator must be fake. The real samples will all be labeled as 1s and the fakes all as 0s. The labels must be the same size as the data batches. And we’ll load them on the GPU ready for our code to use when appropriate.

# prepare training dataset
trn_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_sz, shuffle=True)

# create real/fake data labels
real_lbls = torch.ones(batch_sz, 1)
real_lbls = real_lbls.to(device)

fake_lbls = torch.zeros(batch_sz, 1)
fake_lbls = fake_lbls.to(device)
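Fixed-size label tensors work here because the 1024 training samples split evenly into batches of 32. If they didn't, the last batch would be smaller than `real_lbls` and the loss call would fail on a shape mismatch. A minimal sketch of one workaround is to size the labels from each batch using `batch.size(0)`; the 1000-sample dataset below is made up purely to force an uneven last batch. (Passing `drop_last=True` to the `DataLoader` is the other common fix.)

```python
import math
import torch
from torch.utils.data import DataLoader

# hypothetical dataset whose size is NOT a multiple of the batch size
n_samples, batch_sz = 1000, 32
x = 2 * math.pi * torch.rand(n_samples)
data = torch.stack([x, torch.sin(x)], dim=1)  # shape (1000, 2)

loader = DataLoader(data, batch_size=batch_sz, shuffle=True)

batch_sizes = []
for batch in loader:
    n = batch.size(0)             # 32 for most batches, 8 for the last one
    real_lbls = torch.ones(n, 1)  # labels always match the batch
    fake_lbls = torch.zeros(n, 1)
    assert real_lbls.shape[0] == fake_lbls.shape[0] == n
    batch_sizes.append(n)

print(len(batch_sizes), batch_sizes[-1])
```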

Define and Instantiate Models

The Discriminator is pretty much the same as the binary classifier from the first set of posts in this series on machine learning with PyTorch. It is, after all, acting as a binary classifier. However, I will be adding a Dropout after each layer except the output layer, a change from the earlier classifier.

Dropout

This has proven to be an effective technique for regularization and preventing the co-adaptation of neurons as described in the paper Improving neural networks by preventing co-adaptation of feature detectors.

torch.nn > Dropout
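A quick way to see what `nn.Dropout` actually does: in training mode it zeroes each element with probability `p` and scales the survivors by `1/(1 - p)`, so the expected value is unchanged; in evaluation mode (after `.eval()`) it passes its input through untouched. This sketch checks both behaviours on a tensor of ones, using p = 0.16 (the post's `do_p` value).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

p = 0.16
drop = nn.Dropout(p=p)
x = torch.ones(10_000)

drop.train()            # training mode: random zeroing plus rescaling
y_train = drop(x)
kept = y_train[y_train != 0]

drop.eval()             # evaluation mode: identity
y_eval = drop(x)

print(f"fraction zeroed: {(y_train == 0).float().mean().item():.3f}")
print(f"surviving value: {kept[0].item():.4f} (1 / (1 - p) = {1 / (1 - p):.4f})")
print(f"eval output unchanged: {torch.equal(y_eval, x)}")
```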

The Generator will be simpler still: just a series of linear transformations, with a non-linear activation function after all but the output layer, and one hidden layer fewer than the Discriminator.

Without activation functions, neural networks would simply be a series of linear transformations, which would limit their ability to learn complex patterns and relationships in data.

Activation Functions in PyTorch, Jason Brownlee
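The quoted point is easy to check numerically. Two `nn.Linear` layers stacked with no activation between them compute W2(W1 x + b1) + b2, which is just a single linear layer with weight W2 W1 and bias W2 b1 + b2. This sketch builds that single equivalent layer and confirms it gives the same outputs; the sizes (2 → 16 → 2) are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# two linear layers with NO activation in between
l1 = nn.Linear(2, 16)
l2 = nn.Linear(16, 2)
stacked = nn.Sequential(l1, l2)

# fold them into one equivalent linear layer: W = W2 @ W1, b = W2 @ b1 + b2
merged = nn.Linear(2, 2)
with torch.no_grad():
    merged.weight.copy_(l2.weight @ l1.weight)
    merged.bias.copy_(l2.weight @ l1.bias + l2.bias)

x = torch.randn(5, 2)
with torch.no_grad():
    same = torch.allclose(stacked(x), merged(x), atol=1e-5)

print("stacked linear layers == one linear layer:", same)
```

Putting a `nn.ReLU()` between `l1` and `l2` breaks this folding, which is exactly why the models above interleave activations with their linear layers.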

And, once instantiated, we’ll move the models to the GPU (if available) ready for use.

# define Discriminator model class
class Discriminator(nn.Module):
  def __init__(self, do_p=0.3):
    super().__init__()
    self.model = nn.Sequential(
      nn.Linear(2, 256),
      nn.ReLU(),
      nn.Dropout(p=do_p),
      nn.Linear(256, 128),
      nn.ReLU(),
      nn.Dropout(p=do_p),
      nn.Linear(128, 64),
      nn.ReLU(),
      nn.Dropout(p=do_p),
      nn.Linear(64, 1),
      nn.Sigmoid()
    )
    
  def forward(self, x):
    outp = self.model(x)
    return outp


# define Generator model class, very simple this one
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
      nn.Linear(2, 16),
      nn.ReLU(),
      nn.Linear(16, 32),
      nn.ReLU(),
      nn.Linear(32, 2),
    )
    
  def forward(self, x):
    outp = self.model(x)
    return outp


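# instantiate the models, specify loss and optimizer functions
# assumed loss: binary cross-entropy, matching the sigmoid output and 1/0 real/fake labels
loss_fn = nn.BCELoss()
discrm = Discriminator()
discrm.to(device)
opt_d = torch.optim.Adam(discrm.parameters(), lr=lr)
genatr = Generator()
genatr.to(device)
opt_g = torch.optim.Adam(genatr.parameters(), lr=lr)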
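To put a number on "even simpler", here is a small sketch that rebuilds the two layer stacks from the classes above (as plain `nn.Sequential` objects, so it runs on its own) and counts their trainable parameters. Each `nn.Linear(i, o)` contributes i*o weights plus o biases; `Dropout`, `ReLU` and `Sigmoid` have none.

```python
import torch.nn as nn


def n_params(model):
    # total number of trainable values (weights + biases)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# same layer sizes as the post's Discriminator
disc_layers = nn.Sequential(
    nn.Linear(2, 256), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(64, 1), nn.Sigmoid(),
)

# same layer sizes as the post's Generator
gen_layers = nn.Sequential(
    nn.Linear(2, 16), nn.ReLU(),
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 2),
)

print("discriminator parameters:", n_params(disc_layers))  # 41985
print("generator parameters:    ", n_params(gen_layers))   # 658
```

So the discriminator has roughly 64 times as many parameters as the generator.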

Early Stop

GANs present a bit of a problem when it comes to determining how many epochs should be used for training, because they are trained differently from the classification models we looked at previously. Ideally the Generator keeps getting better with additional training, but the two networks are competing: as the Generator improves, the Discriminator’s job gets harder and its loss goes up rather than down. So looking for converging model parameters using some loss calculation on the discriminator is not going to work.
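A small worked example of why the discriminator's loss is a poor convergence signal: if the generator were perfect, the best the discriminator could do is output 0.5 for every sample, and binary cross-entropy at 0.5 is ln 2 ≈ 0.693 whichever label is used. The sketch assumes `nn.BCELoss` as the loss function. Since each training step adds a real-batch loss and a fake-batch loss for the discriminator, its logged value would hover near 2 ln 2 ≈ 1.386 and the generator's near ln 2. The values logged between epochs 25 and 125 in the run below (roughly 1.37 to 1.40 for the discriminator, 0.71 to 0.77 for the generator) sit close to these.

```python
import math
import torch
import torch.nn as nn

bce = nn.BCELoss()  # assumed loss function

# a discriminator that cannot tell real from fake outputs 0.5 everywhere
d_out = torch.full((32, 1), 0.5)
real_lbls = torch.ones(32, 1)
fake_lbls = torch.zeros(32, 1)

loss_real = bce(d_out, real_lbls).item()
loss_fake = bce(d_out, fake_lbls).item()

print(f"loss vs real labels: {loss_real:.4f}")              # ln 2
print(f"loss vs fake labels: {loss_fake:.4f}")              # ln 2
print(f"real + fake (disc):  {loss_real + loss_fake:.4f}")  # 2 ln 2
print(f"ln 2 = {math.log(2):.4f}")
```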

One common approach is to visually inspect the generated data instances and call it quits when we are happy with the output. But in our case, we are generating data matching a very specific formula, so we can take the mean squared error of the fake data against that equation. The early stop class is pretty much identical to the one we used for the classifiers in the earlier posts. Something like the following.

# mean squared error, used to score the generator's output against the true sine curve
mse = nn.MSELoss()

# early stop generator performance function
# nice and simple arithmetic: how far is each fake y from sin(fake x)?
def gen_perf(fakes):
  # scoring only, so no need to track gradients
  with torch.no_grad():
    r_vals = torch.sin(fakes[:, 0])
    mse_loss = mse(fakes[:,1], r_vals)
  return mse_loss


# early stop class
class EarlyStop:
  def __init__(self, e_wait=100):
    self.e_wait = e_wait   # stop if notbttr >= e_wait
    self.notbttr = 0    # number of epochs without improvement
    # last best loss, start with a big number since we're looking for a min
    self.min_loss = float('inf')
  
  # g_loss: current value of the generator performance metric (gen_perf)
  def is_stop(self, g_loss):
    if g_loss < self.min_loss:
      self.min_loss = g_loss
      self.notbttr = 0
    else:
      self.notbttr += 1
    return ((self.notbttr >= self.e_wait), self.notbttr)


# instantiate class
chk_stop = EarlyStop(e_stop)
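As a sanity check on the metric, here is a sketch of how `gen_perf`-style scoring behaves at two extremes, assuming `mse = nn.MSELoss()`. Fakes lying exactly on the curve score 0; a generator that outputs y = 0 for every x scores about 0.5, the mean of sin²(x) over a full period.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
mse = nn.MSELoss()  # assumed definition of the post's `mse`


def gen_perf(fakes):
    # distance of the generated y values from sin(generated x)
    r_vals = torch.sin(fakes[:, 0])
    return mse(fakes[:, 1], r_vals)


x = 2 * math.pi * torch.rand(10_000)

perfect = torch.stack([x, torch.sin(x)], dim=1)      # points on the curve
flat = torch.stack([x, torch.zeros_like(x)], dim=1)  # y = 0 everywhere

print(f"perfect fakes: {gen_perf(perfect).item():.4f}")
print(f"flat line:     {gen_perf(flat).item():.4f}")  # close to 0.5
```

Note that this metric only checks that each point sits on the curve. A generator that crowded all of its points into one short stretch of x would also score near 0, so the saved plots are still worth watching.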

Training the GAN

Some Helpful Functions

Okay, let’s get to the meat of the matter. To help keep the code tidy I am going to write three functions to facilitate the training loop. They will train the discriminator on real data, then on fake data and finally train the generator.

# okay some functions to facilitate training the gan
def train_disc_real(samples):
  # load 'real' samples to gpu
  samples = samples.to(device)
  # zero out the gradients on each iteration
  # don't want them accumulating over a number of training cycles or epochs
  opt_d.zero_grad()
  out_d = discrm(samples)
  # calculate loss and backpropagate to update weights and biases
  loss_d = loss_fn(out_d, real_lbls)
  loss_d.backward()
  opt_d.step()
  return loss_d


def train_disc_fake():
  # generate some fake data
  noise = torch.randn((batch_sz, 2))
  noise = noise.to(device)
  # detach() so this discriminator update doesn't backpropagate into the generator
  fakes = genatr(noise).detach()
  # train discriminator on fake data
  opt_d.zero_grad()
  out_d = discrm(fakes)
  loss_d = loss_fn(out_d, fake_lbls)
  loss_d.backward()
  opt_d.step()
  return loss_d


def train_genr():
  # generate random tuples from latent 2-D space
  noise = torch.randn((batch_sz, 2))
  noise = noise.to(device)
  # train generator
  opt_g.zero_grad()
  fakes = genatr(noise)
  out_disc = discrm(fakes)
  # want to convince discriminator fakes are real, so must use real labels
  loss_g = loss_fn(out_disc, real_lbls)
  loss_g.backward()
  opt_g.step()
  return loss_g, fakes
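Why train the generator against the real labels? Assuming `loss_fn` is binary cross-entropy, BCE against real labels is -log D(G(z)). The textbook alternative, minimizing log(1 - D(G(z))), produces tiny gradients exactly when the generator is worst, i.e. when the discriminator confidently rejects its fakes. This sketch compares the two gradients with respect to the discriminator's pre-sigmoid output (its logit) when D(G(z)) = 0.01; the logit value is chosen only to hit that probability.

```python
import math
import torch

# discriminator logit such that sigmoid(logit) = 0.01: a fake that is easily spotted
p = 0.01
logit_value = math.log(p / (1 - p))

# "saturating" generator loss: minimize log(1 - D(G(z)))
s1 = torch.tensor(logit_value, requires_grad=True)
torch.log(1 - torch.sigmoid(s1)).backward()

# "non-saturating" loss, which BCE against real labels gives: minimize -log D(G(z))
s2 = torch.tensor(logit_value, requires_grad=True)
(-torch.log(torch.sigmoid(s2))).backward()

print(f"saturating gradient:     {s1.grad.item():.4f}")  # about -0.01
print(f"non-saturating gradient: {s2.grad.item():.4f}")  # about -0.99
```

Both gradients point the same way (raise the logit), but the second is about 99 times larger, which is what lets the generator make progress early in training.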

I am also going to add a function to periodically output the losses for both the discriminator and generator. As well, the function will plot the current fake and real data and save that to a file (multiple files) so we can see how the generator progresses over time.

def chk_perf(epoch, loss_g, loss_d, n, do_es, es_cnt):
  # loss_g, loss_d: losses summed over the epoch's batches; n: number of batches
  # note: loss_d adds the real and fake discriminator losses for every batch
  ep = epoch + 1
  g_avg = loss_g.item() / n
  d_avg = loss_d.item() / n
  print(f"epoch {ep}: mean gen loss: {g_avg}, mean disc loss: {d_avg}, stop: {do_es}, estop cnt: {es_cnt}")
  # move the latest fake samples (the global `fakes` from the training loop) to cpu as a numpy array
  f_samp = fakes.detach().cpu().numpy()
  plt.figure()
  plt.plot(ds_train[:, 0], ds_train[:, 1], ls="", marker=".", c="g", label="Real data")
  plt.plot(f_samp[:, 0], f_samp[:, 1], ls="", marker=".", c="k", label="Generated samples")
  plt.title(f"Epoch {ep}")
  plt.legend()
  plt.savefig(img_dir / f"gan_{ep}.png")
  plt.close()

Training Loop

And finally we are ready to train the model. With all of the above functions, it’s a pretty straightforward and tidy loop.

And I am saving the models when done. I really only need to save the Generator once training is done, but I tend to be overly cautious. I am also using a different approach to saving the models than the one I used with the earlier classification models: TorchScript rather than the model class plus its parameters.
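For comparison, a minimal sketch of the two saving approaches side by side, using a generator-shaped `nn.Sequential` and a temporary directory (the file names are placeholders). Saving a `state_dict` stores only the parameters, so the architecture has to be rebuilt in code before loading; a TorchScript file from `torch.jit.script` stores the code and parameters together, so `torch.jit.load` can rebuild the model without the class definition.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def make_gen():
    # same layer sizes as the post's Generator
    return nn.Sequential(
        nn.Linear(2, 16), nn.ReLU(),
        nn.Linear(16, 32), nn.ReLU(),
        nn.Linear(32, 2),
    )


torch.manual_seed(0)
gen = make_gen().eval()
noise = torch.randn(8, 2)

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)

    # 1) parameters only: make_gen() is needed to rebuild the architecture
    torch.save(gen.state_dict(), tmp / "gen_state.pt")
    gen_a = make_gen()
    gen_a.load_state_dict(torch.load(tmp / "gen_state.pt"))
    gen_a.eval()

    # 2) TorchScript: architecture and parameters in one file
    torch.jit.script(gen).save(str(tmp / "gen_script.pt"))
    gen_b = torch.jit.load(str(tmp / "gen_script.pt"))
    gen_b.eval()

    with torch.no_grad():
        out = gen(noise)
        same_a = torch.allclose(out, gen_a(noise))
        same_b = torch.allclose(out, gen_b(noise))

print("state_dict round trip matches: ", same_a)
print("TorchScript round trip matches:", same_b)
```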

st_tm = time.perf_counter()
for epoch in range(max_ep):
  # aggregate the losses over all the batches in each epoch
  agg_l_g = 0
  agg_l_d = 0
  for i, r_samp in enumerate(trn_loader):
    loss_d = train_disc_real(r_samp)
    # detach() so the running totals don't keep each batch's autograd graph alive
    agg_l_d += loss_d.detach()
    loss_d = train_disc_fake()
    agg_l_d += loss_d.detach()
    loss_g, fakes = train_genr()
    agg_l_g += loss_g.detach()
  g_perf = gen_perf(fakes).item()
  do_stop, e_cnt = chk_stop.is_stop(g_perf)
  # i is zero-based, so i + 1 is the number of batches in this epoch
  if epoch == 0 or (epoch+1) % 25 == 0:
    chk_perf(epoch, agg_l_g, agg_l_d, i + 1, do_stop, e_cnt)
  if do_stop:
    chk_perf(epoch, agg_l_g, agg_l_d, i + 1, do_stop, e_cnt)
    break
nd_tm = time.perf_counter()
print(f"\ntime to train GAN : {nd_tm - st_tm}")

# Save models, using torchscript instead of model and parameters
sv_dir.mkdir(parents=True, exist_ok=True)
g_script = torch.jit.script(genatr)
g_script.save(sv_dir / "gen_sin.pt")
d_script = torch.jit.script(discrm)
d_script.save(sv_dir / "disc_sin.pt")

(mclp-3.12) PS F:\learn\mcl_pytorch\chap3> python gan_sine.py
epoch 1: mean gen loss: 1.490629650297619, mean disc loss: 0.6553892710852245, stop: False, estop cnt: 0
epoch 25: mean gen loss: 0.7633150494287885, mean disc loss: 1.3762268793015253, stop: False, estop cnt: 24
epoch 50: mean gen loss: 0.7133721851167225, mean disc loss: 1.4013696095300099, stop: False, estop cnt: 24
epoch 75: mean gen loss: 0.7568924313499814, mean disc loss: 1.3840772840711806, stop: False, estop cnt: 11
epoch 100: mean gen loss: 0.7518328106592572, mean disc loss: 1.3785276867094494, stop: False, estop cnt: 36
epoch 125: mean gen loss: 0.7727527921161954, mean disc loss: 1.374955313546317, stop: False, estop cnt: 1
epoch 150: mean gen loss: 0.7781505887470548, mean disc loss: 1.3515293181888641, stop: False, estop cnt: 18
epoch 175: mean gen loss: 0.8378843882727245, mean disc loss: 1.330979241265191, stop: False, estop cnt: 3
epoch 200: mean gen loss: 0.7926055060492622, mean disc loss: 1.356121123783172, stop: False, estop cnt: 28
epoch 225: mean gen loss: 0.7780661204504589, mean disc loss: 1.3575050717308408, stop: False, estop cnt: 13
epoch 250: mean gen loss: 0.8212702312166729, mean disc loss: 1.3461049397786458, stop: False, estop cnt: 1
epoch 275: mean gen loss: 0.9695556277320498, mean disc loss: 1.2565888904389881, stop: False, estop cnt: 26
epoch 300: mean gen loss: 0.99287354000031, mean disc loss: 1.2805784921797494, stop: False, estop cnt: 51
epoch 325: mean gen loss: 1.1120256696428572, mean disc loss: 1.2367453196692089, stop: False, estop cnt: 76
epoch 350: mean gen loss: 1.1205902099609375, mean disc loss: 1.2092908828977555, stop: False, estop cnt: 20
epoch 375: mean gen loss: 1.2085493784102181, mean disc loss: 1.1767952328636533, stop: False, estop cnt: 45
epoch 400: mean gen loss: 1.1152263823009672, mean disc loss: 1.2751183888268849, stop: False, estop cnt: 4
epoch 425: mean gen loss: 1.131700909326947, mean disc loss: 1.2235913957868303, stop: False, estop cnt: 29
epoch 450: mean gen loss: 1.2373709300207714, mean disc loss: 1.2141070895724826, stop: False, estop cnt: 54
epoch 475: mean gen loss: 1.1914457290891618, mean disc loss: 1.2426160782102555, stop: False, estop cnt: 79
epoch 500: mean gen loss: 1.171293228391617, mean disc loss: 1.2177117968362474, stop: False, estop cnt: 104
epoch 525: mean gen loss: 1.2852269732762898, mean disc loss: 1.2131259252154638, stop: False, estop cnt: 21
epoch 550: mean gen loss: 1.3274083213200645, mean disc loss: 1.1802967616489954, stop: False, estop cnt: 46
epoch 575: mean gen loss: 1.3269411117311507, mean disc loss: 1.18731689453125, stop: False, estop cnt: 5
epoch 600: mean gen loss: 1.391893296014695, mean disc loss: 1.1478795853872148, stop: False, estop cnt: 30
epoch 625: mean gen loss: 1.2702027578202506, mean disc loss: 1.1980070083860368, stop: False, estop cnt: 55
epoch 650: mean gen loss: 1.2800313556005085, mean disc loss: 1.2014923095703125, stop: False, estop cnt: 80
epoch 675: mean gen loss: 1.1272286309136286, mean disc loss: 1.2705027262369792, stop: False, estop cnt: 105
epoch 700: mean gen loss: 1.2984246148003473, mean disc loss: 1.1907279604957217, stop: False, estop cnt: 130
epoch 725: mean gen loss: 1.2192605639260912, mean disc loss: 1.2454918755425348, stop: False, estop cnt: 155
epoch 750: mean gen loss: 1.4227878631107391, mean disc loss: 1.1736723884703621, stop: False, estop cnt: 180
epoch 775: mean gen loss: 1.2188325912233382, mean disc loss: 1.2328755212208582, stop: False, estop cnt: 205
epoch 800: mean gen loss: 1.3562516712007069, mean disc loss: 1.203795175703745, stop: False, estop cnt: 230
epoch 825: mean gen loss: 1.1876869807167658, mean disc loss: 1.228102305578807, stop: False, estop cnt: 255
epoch 850: mean gen loss: 1.2549417889307415, mean disc loss: 1.2278522309802828, stop: False, estop cnt: 280
epoch 875: mean gen loss: 1.5368859427315849, mean disc loss: 1.0935672578357516, stop: False, estop cnt: 305
epoch 900: mean gen loss: 1.3320416647290427, mean disc loss: 1.2002282521081349, stop: False, estop cnt: 330
epoch 925: mean gen loss: 1.4170021178230408, mean disc loss: 1.1270537603469122, stop: False, estop cnt: 355
epoch 950: mean gen loss: 1.3654512677873885, mean disc loss: 1.1822577582465277, stop: False, estop cnt: 380
epoch 975: mean gen loss: 1.22968994625031, mean disc loss: 1.205474126906622, stop: False, estop cnt: 405
epoch 1000: mean gen loss: 1.1179779294937375, mean disc loss: 1.2524605402870783, stop: False, estop cnt: 430
epoch 1025: mean gen loss: 1.3628724113343254, mean disc loss: 1.1470592438228546, stop: False, estop cnt: 455
epoch 1050: mean gen loss: 1.4263184562562004, mean disc loss: 1.168227059500558, stop: False, estop cnt: 480
epoch 1075: mean gen loss: 1.1816352965339783, mean disc loss: 1.2356126573350694, stop: False, estop cnt: 505
epoch 1100: mean gen loss: 1.3791729155040922, mean disc loss: 1.198869129968068, stop: False, estop cnt: 530
epoch 1125: mean gen loss: 1.29012940421937, mean disc loss: 1.2150079636346727, stop: False, estop cnt: 555
epoch 1150: mean gen loss: 1.3286929660373263, mean disc loss: 1.1638463338216145, stop: False, estop cnt: 14
epoch 1175: mean gen loss: 1.2029019310360862, mean disc loss: 1.2165484958224826, stop: False, estop cnt: 39
epoch 1200: mean gen loss: 1.4608482481941345, mean disc loss: 1.1648771497938368, stop: False, estop cnt: 64
epoch 1225: mean gen loss: 1.4565243191189237, mean disc loss: 1.1623199705093625, stop: False, estop cnt: 89
epoch 1250: mean gen loss: 2.005118354918465, mean disc loss: 0.9612814282614087, stop: False, estop cnt: 114
epoch 1275: mean gen loss: 1.6742328462146578, mean disc loss: 1.0521458217075892, stop: False, estop cnt: 139
epoch 1300: mean gen loss: 1.5375321403382316, mean disc loss: 1.0848882765997023, stop: False, estop cnt: 164
epoch 1325: mean gen loss: 1.281940399654328, mean disc loss: 1.1498070368691096, stop: False, estop cnt: 189
epoch 1350: mean gen loss: 1.6145350138346355, mean disc loss: 1.0683160206628224, stop: False, estop cnt: 214
epoch 1375: mean gen loss: 1.3544540405273438, mean disc loss: 1.1372032771034846, stop: False, estop cnt: 239
epoch 1400: mean gen loss: 1.498809572250124, mean disc loss: 1.1028814164419023, stop: False, estop cnt: 264
epoch 1425: mean gen loss: 1.7213522290426588, mean disc loss: 1.0377409193250868, stop: False, estop cnt: 289
epoch 1450: mean gen loss: 1.5813751220703125, mean disc loss: 1.0866085234142484, stop: False, estop cnt: 314
epoch 1475: mean gen loss: 1.6748748052687872, mean disc loss: 1.053771730453249, stop: False, estop cnt: 339
epoch 1500: mean gen loss: 1.3635297502790178, mean disc loss: 1.1253923688616072, stop: False, estop cnt: 364
epoch 1525: mean gen loss: 1.4777811443994915, mean disc loss: 1.1309866526770214, stop: False, estop cnt: 389
epoch 1550: mean gen loss: 1.818591768779452, mean disc loss: 1.0641656300378224, stop: False, estop cnt: 414
epoch 1575: mean gen loss: 1.3192994859483507, mean disc loss: 1.1325638786194816, stop: False, estop cnt: 439
epoch 1600: mean gen loss: 1.431029789031498, mean disc loss: 1.1553302341037326, stop: False, estop cnt: 464
epoch 1625: mean gen loss: 1.4295746334015378, mean disc loss: 1.1681690518818204, stop: False, estop cnt: 489
epoch 1650: mean gen loss: 1.472423371814546, mean disc loss: 1.1381875900995164, stop: False, estop cnt: 514
epoch 1675: mean gen loss: 1.3357574220687625, mean disc loss: 1.1963881235274056, stop: False, estop cnt: 539
epoch 1700: mean gen loss: 2.119165571909102, mean disc loss: 1.0401869274321056, stop: False, estop cnt: 564
epoch 1725: mean gen loss: 1.2796418931749132, mean disc loss: 1.1984588380843875, stop: False, estop cnt: 589
epoch 1750: mean gen loss: 1.462282453264509, mean disc loss: 1.1458179534427704, stop: False, estop cnt: 614
epoch 1775: mean gen loss: 1.4013410295758928, mean disc loss: 1.2032413785419767, stop: False, estop cnt: 639
epoch 1800: mean gen loss: 1.8006429520864335, mean disc loss: 1.1387399340432787, stop: False, estop cnt: 664
epoch 1825: mean gen loss: 1.7469746423146082, mean disc loss: 1.1024441189236112, stop: False, estop cnt: 689
epoch 1850: mean gen loss: 1.8426899985661582, mean disc loss: 1.1001387096586681, stop: False, estop cnt: 714
epoch 1875: mean gen loss: 1.529631115141369, mean disc loss: 1.1396466209774925, stop: False, estop cnt: 739
epoch 1900: mean gen loss: 1.3791856311616444, mean disc loss: 1.1976354689825148, stop: False, estop cnt: 764
epoch 1925: mean gen loss: 1.5542034573025174, mean disc loss: 1.1473316010974703, stop: False, estop cnt: 789
epoch 1950: mean gen loss: 1.608196561298673, mean disc loss: 1.1532649691142733, stop: False, estop cnt: 814
epoch 1975: mean gen loss: 1.354766361297123, mean disc loss: 1.2060589260525174, stop: False, estop cnt: 7
epoch 2000: mean gen loss: 1.6041793823242188, mean disc loss: 1.1407540941995287, stop: False, estop cnt: 32
epoch 2025: mean gen loss: 1.6113258240714905, mean disc loss: 1.118849860297309, stop: False, estop cnt: 57
epoch 2050: mean gen loss: 1.5889006115141369, mean disc loss: 1.158093770345052, stop: False, estop cnt: 82
epoch 2075: mean gen loss: 1.1280906313941592, mean disc loss: 1.2641022697327629, stop: False, estop cnt: 107
epoch 2100: mean gen loss: 1.263321528359065, mean disc loss: 1.2380415901305184, stop: False, estop cnt: 132
epoch 2125: mean gen loss: 1.7042105538504464, mean disc loss: 1.0968881031823536, stop: False, estop cnt: 157
epoch 2150: mean gen loss: 1.4343258085704984, mean disc loss: 1.173721918984065, stop: False, estop cnt: 182
epoch 2175: mean gen loss: 1.4060797312903026, mean disc loss: 1.2016790480840773, stop: False, estop cnt: 207
epoch 2200: mean gen loss: 1.3623887319413444, mean disc loss: 1.178193834092882, stop: False, estop cnt: 232
epoch 2225: mean gen loss: 1.3237030998108879, mean disc loss: 1.199053688654824, stop: False, estop cnt: 257
epoch 2250: mean gen loss: 1.255095951140873, mean disc loss: 1.2113934471493675, stop: False, estop cnt: 282
epoch 2275: mean gen loss: 1.513050018794953, mean disc loss: 1.1574203249007937, stop: False, estop cnt: 307
epoch 2300: mean gen loss: 1.55085936046782, mean disc loss: 1.1212103707449776, stop: False, estop cnt: 332
epoch 2325: mean gen loss: 1.2352994888547868, mean disc loss: 1.226092989482577, stop: False, estop cnt: 357
epoch 2350: mean gen loss: 1.198822505890377, mean disc loss: 1.2470870245070684, stop: False, estop cnt: 382
epoch 2375: mean gen loss: 1.3365892682756697, mean disc loss: 1.2029145255921379, stop: False, estop cnt: 407
epoch 2400: mean gen loss: 1.4339089772057911, mean disc loss: 1.2012264917767237, stop: False, estop cnt: 432
epoch 2425: mean gen loss: 1.3801475403800842, mean disc loss: 1.190837678455171, stop: False, estop cnt: 457
epoch 2450: mean gen loss: 1.1451099940708704, mean disc loss: 1.2679090954008556, stop: False, estop cnt: 482
epoch 2475: mean gen loss: 1.1994972834511408, mean disc loss: 1.2438997541155135, stop: False, estop cnt: 507
epoch 2500: mean gen loss: 1.1065294780428447, mean disc loss: 1.262543209015377, stop: False, estop cnt: 532
epoch 2525: mean gen loss: 1.2589571513826885, mean disc loss: 1.2449618142748635, stop: False, estop cnt: 557
epoch 2550: mean gen loss: 1.321533687531002, mean disc loss: 1.192488534109933, stop: False, estop cnt: 582
epoch 2575: mean gen loss: 1.3483537946428572, mean disc loss: 1.2131730336991569, stop: False, estop cnt: 607
epoch 2600: mean gen loss: 1.172957768515935, mean disc loss: 1.2144683353484622, stop: False, estop cnt: 632
epoch 2625: mean gen loss: 1.2321067688957092, mean disc loss: 1.218670678517175, stop: False, estop cnt: 657
epoch 2650: mean gen loss: 1.1197867015051464, mean disc loss: 1.251418219672309, stop: False, estop cnt: 682
epoch 2675: mean gen loss: 1.2424728151351687, mean disc loss: 1.22645266093905, stop: False, estop cnt: 14
epoch 2700: mean gen loss: 1.405977037217882, mean disc loss: 1.1853150867280506, stop: False, estop cnt: 39
epoch 2725: mean gen loss: 1.3614652118985615, mean disc loss: 1.206495375860305, stop: False, estop cnt: 64
epoch 2750: mean gen loss: 1.3693294222392733, mean disc loss: 1.1798015776134672, stop: False, estop cnt: 89
epoch 2775: mean gen loss: 1.1490822443886408, mean disc loss: 1.240739489358569, stop: False, estop cnt: 114
epoch 2800: mean gen loss: 1.2996763199094743, mean disc loss: 1.210337926471044, stop: False, estop cnt: 139
epoch 2825: mean gen loss: 1.483689202202691, mean disc loss: 1.1688670809306796, stop: False, estop cnt: 164
epoch 2850: mean gen loss: 1.7806356520879836, mean disc loss: 1.0637306334480408, stop: False, estop cnt: 189
epoch 2875: mean gen loss: 1.3066463167705233, mean disc loss: 1.1941248575846355, stop: False, estop cnt: 214
epoch 2900: mean gen loss: 1.3563794332837302, mean disc loss: 1.1830919053819444, stop: False, estop cnt: 239
epoch 2925: mean gen loss: 1.3945569235181052, mean disc loss: 1.172199915325831, stop: False, estop cnt: 264
epoch 2950: mean gen loss: 1.6389867389012898, mean disc loss: 1.123992193312872, stop: False, estop cnt: 289
epoch 2975: mean gen loss: 1.274310278514075, mean disc loss: 1.1855599539620536, stop: False, estop cnt: 314
epoch 3000: mean gen loss: 1.219954112219432, mean disc loss: 1.1844482421875, stop: False, estop cnt: 339
epoch 3025: mean gen loss: 1.2891053699311756, mean disc loss: 1.1808423239087302, stop: False, estop cnt: 364
epoch 3050: mean gen loss: 1.6990057324606276, mean disc loss: 1.1149866013299852, stop: False, estop cnt: 389
epoch 3075: mean gen loss: 1.223278590611049, mean disc loss: 1.1877206469339037, stop: False, estop cnt: 414
epoch 3100: mean gen loss: 1.1774324689592635, mean disc loss: 1.203657725500682, stop: False, estop cnt: 439
epoch 3125: mean gen loss: 1.4640097239660839, mean disc loss: 1.1248075697157118, stop: False, estop cnt: 464
epoch 3150: mean gen loss: 1.7314837016756572, mean disc loss: 1.0935484871031746, stop: False, estop cnt: 489
epoch 3175: mean gen loss: 1.4085485064794148, mean disc loss: 1.1427561442057292, stop: False, estop cnt: 514
epoch 3200: mean gen loss: 1.2099351428803944, mean disc loss: 1.2031935434492806, stop: False, estop cnt: 539
epoch 3225: mean gen loss: 1.8301285032242063, mean disc loss: 1.0841923062763517, stop: False, estop cnt: 564
epoch 3250: mean gen loss: 1.370682610405816, mean disc loss: 1.1452503507099454, stop: False, estop cnt: 589
epoch 3275: mean gen loss: 1.5058947366381448, mean disc loss: 1.0957745748852927, stop: False, estop cnt: 614
epoch 3300: mean gen loss: 1.6115014212472099, mean disc loss: 1.077863541860429, stop: False, estop cnt: 639
epoch 3325: mean gen loss: 1.6267370799231151, mean disc loss: 1.0922696552579365, stop: False, estop cnt: 664
epoch 3350: mean gen loss: 1.3148389543805803, mean disc loss: 1.1649569556826638, stop: False, estop cnt: 689
epoch 3375: mean gen loss: 1.5121439373682415, mean disc loss: 1.1057283916170635, stop: False, estop cnt: 714
epoch 3400: mean gen loss: 1.6980152432880704, mean disc loss: 1.0578569684709822, stop: False, estop cnt: 739
epoch 3425: mean gen loss: 1.2621185060531375, mean disc loss: 1.192001706077939, stop: False, estop cnt: 764
epoch 3450: mean gen loss: 1.6520425705682664, mean disc loss: 1.0967957027374753, stop: False, estop cnt: 789
epoch 3475: mean gen loss: 1.2660870022243924, mean disc loss: 1.1714248657226562, stop: False, estop cnt: 814
epoch 3500: mean gen loss: 1.4377715095641121, mean disc loss: 1.1041486225430928, stop: False, estop cnt: 839
epoch 3525: mean gen loss: 1.151961190359933, mean disc loss: 1.177339099702381, stop: False, estop cnt: 864
epoch 3550: mean gen loss: 1.4643651568700398, mean disc loss: 1.1120365687779017, stop: False, estop cnt: 889
epoch 3575: mean gen loss: 1.5067418416341145, mean disc loss: 1.1052450755285839, stop: False, estop cnt: 914
epoch 3600: mean gen loss: 1.154616461859809, mean disc loss: 1.1967544555664062, stop: False, estop cnt: 939
epoch 3625: mean gen loss: 1.2519183688693576, mean disc loss: 1.1373306758820065, stop: False, estop cnt: 964
epoch 3650: mean gen loss: 1.263811505030072, mean disc loss: 1.1594447786845858, stop: False, estop cnt: 989
epoch 3675: mean gen loss: 1.5548402089921256, mean disc loss: 1.0920204283699158, stop: False, estop cnt: 1014
epoch 3685: mean gen loss: 1.3102550203838046, mean disc loss: 1.1605891878642733, stop: True, estop cnt: 1024

time to train GAN : 1415.5403338000178

As you can see, early stop was not particularly early: training only stopped at epoch 3685, after the generator metric had gone 1024 epochs without improving. And, looking at the images, there were some good fakes generated a fair bit earlier in the training. If I decide to run the training again, I would look at reducing the e-stop value or trying a different method to determine early stop.
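One possible "different method", sketched here under my own assumptions and not tried on this GAN in the post: only count an epoch as an improvement if the metric beats the best value by at least some margin `min_delta`, and score the generator on a larger, fixed batch of noise each epoch instead of the last 32 training fakes, so the metric is less noisy. The names `EarlyStopDelta`, `min_delta` and `score_generator` are hypothetical.

```python
import torch


class EarlyStopDelta:
    """Hypothetical variant of the post's EarlyStop that ignores tiny improvements."""

    def __init__(self, e_wait=100, min_delta=1e-3):
        self.e_wait = e_wait        # stop once this many epochs pass without improvement
        self.min_delta = min_delta  # minimum drop in the metric that counts as improvement
        self.notbttr = 0            # epochs since the last real improvement
        self.min_loss = float("inf")

    def is_stop(self, g_loss):
        if g_loss < self.min_loss - self.min_delta:
            self.min_loss = g_loss
            self.notbttr = 0
        else:
            self.notbttr += 1
        return (self.notbttr >= self.e_wait), self.notbttr


def score_generator(gen, fixed_noise):
    """Hypothetical helper: MSE from the sine curve on a fixed, larger noise batch."""
    with torch.no_grad():
        fakes = gen(fixed_noise)
        return torch.mean((fakes[:, 1] - torch.sin(fakes[:, 0])) ** 2).item()


# improvements smaller than min_delta no longer reset the counter
stopper = EarlyStopDelta(e_wait=3, min_delta=0.01)
history = [stopper.is_stop(v) for v in [0.50, 0.495, 0.492, 0.491]]
print(history)  # [(False, 0), (False, 1), (False, 2), (True, 3)]
```

In the training loop, something like `fixed_noise = torch.randn(1024, 2, device=device)` would be created once before the first epoch, and `score_generator(genatr, fixed_noise)` passed to `is_stop` at the end of each epoch. With the original `EarlyStop`, each of the three small drops in that history would have reset the counter to 0.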

There were 149 images saved, so I will plot 6 of them to provide a glimpse into how things progressed.

plot_6 = True   # show the 6-panel figure

if plot_6:
  # generate plot of 6 of the saved images
  use_img = [25, 100, 200]
  for i in range(1, 3):
    use_img.append(200 + (i * 1150))
  use_img.append(3625)

  fig = plt.figure(tight_layout=True)
  # suptitle places one title above all of the subplots
  fig.suptitle("Some Sample Training Sessions", fontsize=16)
  for i in range(6):
    plt.subplot(3, 2, i + 1)
    img = img_dir / f"gan_{use_img[i]}.png"
    # read the image in
    pic = plt.imread(img)
    plt.imshow(pic)
    plt.axis("off")
  plt.show()
6 of the images (saved every 25th epoch) comparing generator data against the training data, for epochs 25, 100, 200, 1350, 2500 and 3625
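The count of 149 images follows from when `chk_perf` is called: once after epoch 1, after every 25th epoch, and once more when early stopping fired at epoch 3685. A quick check, which also rebuilds the list of epochs picked for the six-panel figure:

```python
stop_epoch = 3685  # from the training log above

# epochs at which chk_perf saved an image
saved = {1} | set(range(25, stop_epoch + 1, 25)) | {stop_epoch}
print(len(saved))  # 149

# the six epochs shown in the figure
use_img = [25, 100, 200] + [200 + i * 1150 for i in range(1, 3)] + [3625]
print(use_img)  # [25, 100, 200, 1350, 2500, 3625]
assert set(use_img) <= saved  # every chosen file was actually written
```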

Load and Test Generator from File

One last thing before I call this post done: I am going to load the Generator from the file, have it produce a couple of datasets, and plot them as well.

To make this work, I have put the training code in an if block. The else block loads the Generator model from the appropriate file, sets it to evaluation mode and has it produce the data for a sine curve. I also moved the seed-fixing code into an if block of its own, so the seed is only set when training. We don’t want a fixed seed when generating, because every run would then draw the same noise and the generator would produce the same curve each time.

...  ...
# set seed for reproducibility
if not ld_model:
  torch.manual_seed(73)

... ...

if not ld_model:
... ...
else:
  # load generator from file
  new_genr = torch.jit.load(sv_dir / "gen_sin.pt", map_location=device)
  # set to evaluation mode
  new_genr.eval()
  noise = torch.randn((batch_sz, 2)).to(device)
  # no gradients are needed just to generate samples
  with torch.no_grad():
    g_curve = new_genr(noise)
  dt_curve = g_curve.cpu().numpy()

  fig = plt.figure()
  plt.plot(ds_train[:, 0], ds_train[:, 1], ls="", marker=".", markersize=3, c="g", label="Real data")
  plt.plot(dt_curve[:, 0], dt_curve[:, 1], ls="", marker=".", markersize=5, c="k", label="Generated samples")
  plt.title("Sine Curve Generated by GAN")
  plt.legend()
  plt.show()

And here are two examples. It’s pretty hard to tell them apart, but they are different.

1st of 2 example curves produced by the Generator in evaluation mode
2nd of 2 example curves produced by the Generator in evaluation mode
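A small sketch of why the fixed seed had to go for generation: with `torch.manual_seed` set to the same value, `torch.randn` returns exactly the same noise every run, so the generator would draw the same curve every time. Without re-seeding, each call produces fresh noise, which is what makes the two example curves differ.

```python
import torch

torch.manual_seed(73)
noise_a = torch.randn(32, 2)

torch.manual_seed(73)          # same seed again...
noise_b = torch.randn(32, 2)   # ...same "random" noise

noise_c = torch.randn(32, 2)   # no re-seed: the random state has moved on

print("same seed, same noise:", torch.equal(noise_a, noise_b))
print("next draw differs:    ", not torch.equal(noise_b, noise_c))
```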

Done

That’s it for this one. All in all, it took me a couple of days, or three, with lots of searching for code and logic help.

May your GANs treat you with more respect than this one did me.

Next time something a little more complex. Perhaps.

Resources