Okay, let’s get on with this project.

I am going to start by setting up the bits and pieces for the loss function. Then on to creating batches and adding the condition tensor to the image or noise tensor for input to the networks. Then on to the code to actually get the training done. Hopefully all in this post.

Wasserstein Distance with Gradient Penalty

Our loss function for the critic looks like the following.

$$ {critic\_value(fake)} - {critic\_value(real)} + ({weight} * {gradient\_penalty})$$

What that basically says is: the critic, which tries to minimize this value, should give a fake image a low evaluation and a real image a high evaluation, while also keeping the gradient penalty term small. The weight is a constant specifying how much penalty to assign to gradient norms that stray from the value 1. In the original paper the authors set that weight to \( \lambda = 10 \).

Now, I am not going to say I fully understand how the gradient penalty code works. It is more or less copied from other sources. I certainly don’t currently have the knowledge to sort out the details of calculating the gradients and their norms. That will hopefully improve with further reading and research.

We will not be using a function to get the Wasserstein distance. It is a pretty straightforward difference between two values in our case.
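To make that concrete for myself, here is a rough sketch of how I expect the critic loss to get assembled in the training loop. This is very much an assumption at this point; the variable names (real_and_label, fake_and_label) are placeholders, and the actual training code comes in a later post.

  # rough sketch only (assumptions): the critic returns one score per image,
  # g_penalty() is the function defined below, and the weight is 10 as in the paper
  lambda_gp = 10
  c_real = critic(real_and_label)   # scores for real images, want these high
  c_fake = critic(fake_and_label)   # scores for generated images, want these low
  gp = g_penalty(critic, real_and_label, fake_and_label)
  c_loss = c_fake.mean() - c_real.mean() + lambda_gp * gp
  c_loss.backward()                 # minimized by the critic's optimizer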

I saw implementations where there was a separate function to get the gradient and another to calculate the penalty. Others combined it all in one function/method. Some passed in epsilon, the values used to interpolate between the real and fake images. Others calculated those values within the function/method. I am going to go with the do-it-all-in-one-place approach, especially since there are a number of ways to generate the epsilon tensor.
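For example, one variation I saw fairly often builds the epsilon tensor directly with singleton dimensions and lets broadcasting do the rest. Something along these lines (just another way of doing what my function below does):

  # alternate epsilon generation seen in other implementations; assumes the
  # module-level device and a batch of image tensors shaped [b_sz, C, H, W]
  eps = torch.rand(b_sz, 1, 1, 1, device=device, requires_grad=True)
  mx_imgs = (real * eps) + (fake * (1 - eps))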

Okay here’s the function code.

  def g_penalty(critic, real, fake):
    b_sz = real.size()[0]
    # generate epsilon values for interpolating the real and fake images
    e_shape = [1 for _ in real.size()]
    e_shape[0] = b_sz
    eps = torch.rand(e_shape, device=device).requires_grad_()
    eps = eps.expand_as(real)
    mx_imgs = (real * eps) + (fake * (1-eps))
    # get gradients
    c_scores = critic(mx_imgs)
    gradient = torch.autograd.grad(
      inputs=mx_imgs,
      outputs=c_scores,
      grad_outputs=torch.ones_like(c_scores),
      create_graph=True,
      retain_graph=True)[0]
    # Gradients have same shape as real and fake tensors, so flatten to easily take norm
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

Decided to write some test code. I have no idea how to ensure the function is producing the correct output. Basically all I am doing is making sure the code runs without errors. I am only using a batch size of 1 as I don’t really need more images to test the code.
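One possible way to actually verify the output, rather than just checking that the code runs, would be to call the function with a trivially simple critic whose gradient is known in advance. If the critic just sums its input, the gradient with respect to the input is all ones, so the per-image gradient norm is the square root of the number of elements in an image, and the penalty can be worked out by hand. A sketch of that idea follows; it is not something I have actually added to the module.

  def tst_known_gradient():
    # a critic that sums its input has a gradient of all ones, so the gradient
    # norm per image is sqrt(number of elements) and the penalty is predictable
    sum_critic = lambda x: x.sum()
    real = torch.rand(4, 3, 8, 8, device=device)
    fake = torch.rand(4, 3, 8, 8, device=device)
    t_gp = g_penalty(sum_critic, real, fake)
    expected = (torch.tensor(3 * 8 * 8.0).sqrt() - 1) ** 2   # roughly 165.3
    print(f"t_gp: {t_gp}, expected: {expected}")

For now though, the simpler smoke test below is what I actually wrote and ran.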

  def test_gradient_penalty():
    img = Image.open(data_dir/'NoG'/'face-1.png')
    img = transform(img).to(device)
    # note change in class label
    label = 1
    onehot=torch.zeros((2))
    onehot[label] = 1
    channels = torch.zeros((2, img_sz, img_sz)).to(device)
    channels[label, :, :] = 1
    img_and_label = torch.cat([img, channels], dim=0)

    fake_img = torch.rand(3, 256, 256).to(device)
    fake_and_label = torch.cat([fake_img, channels], dim=0)

    t_gp = g_penalty(critic, img_and_label, fake_and_label)
    print(f"t_gp: {t_gp}")

  test_gradient_penalty()

And, eventually, no errors, with the following output. There were some bugs along the way, all due to typos, so I am not going to discuss them here. And, once again, I had forgotten to send the fake image to the GPU.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap5> python wgan-wp_g_ng.py
t_gp: 31.70458984375

And if we multiply that by 10, it seems like an awfully big penalty. Fingers crossed. But I expect the large value is due to the real image being normalized to a different range of values than the fake image contained.
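A quick way to confirm that suspicion would be to add a couple of print statements to the test function and compare the value ranges of the two tensors. The real image goes through the Normalize transform (shown in the next section), so it should sit roughly in the range -1 to 1, while torch.rand() produces values in the range 0 to 1. Something like:

    # sketch of a range check inside test_gradient_penalty(); not currently in the module
    # the normalized real image should span roughly -1 to 1, the random fake 0 to 1
    print(f"real range: {img.min().item():.3f} to {img.max().item():.3f}")
    print(f"fake range: {fake_img.min().item():.3f} to {fake_img.max().item():.3f}")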

Data Batches

This process is also borrowed from other people’s code and/or tutorials.

The whole dataset gets modified up front, so that we don’t need to modify each batch in each epoch. That would really be costly in terms of CPU/GPU cycles. The modified dataset is then used to instantiate the data loader.

By modified I mean the data labels get concatenated with the images. Well, one-hot encoded labels. That is what our critic network expects. We will still need to do that for every generator output sent to the critic during training; there is no way to avoid that repetitive process.
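To make that future step concrete, the per-batch version will probably look something like the following sketch. The generator call and variable names are placeholders for code that does not exist yet; the main difference from the dataset code below is that the concatenation happens along dim=1, because these tensors carry a batch dimension.

  # rough sketch: concatenating label channels to generator output during training
  # gen, noise_and_label and lbl_channels are placeholders for later code;
  # lbl_channels would be built the same way as lbl_vec below, just batched
  fake_imgs = gen(noise_and_label)                              # [batch, 3, img_sz, img_sz]
  fake_and_label = torch.cat([fake_imgs, lbl_channels], dim=1)  # [batch, 5, img_sz, img_sz]
  c_fake = critic(fake_and_label)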

Because I am coding the module incrementally, I don’t want to keep generating the modified image data every time I run the module. It is a lengthy process. So I have put the code that produces the modified dataset in an if block, at the end of which I save the dataset to a file. In the else block I load the dataset from that file.

I likely should have shown the initial dataset transform definition in an earlier post, but I don’t think I did. So I will include it here.

  transform = trf.Compose([
    trf.Resize((img_sz, img_sz)),
    trf.ToTensor(),
    trf.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])]) 

... ...

  # build or load dataset for training
  if proc_raw:
    # load images along with labels using ImageFolder
    raw_set = torchvision.datasets.ImageFolder(
      root=data_dir,
      transform=transform
    )
    print(raw_set)
  
    # create complete set of tensors for all images combined with their one-hot encoded labels
    # include original image, label, and one-hot vector in tuple with the combined tensor
    # that information will be needed later
    st_tm = time.perf_counter()
    imgs_lbls = []
    for img, lbl in raw_set:
      one_hot = torch.zeros((nbr_classes))
      one_hot[lbl] = 1
      # the label gets expanded into per-class channels with the same spatial size as the image
      lbl_vec = torch.zeros((nbr_classes, img_sz, img_sz))
      lbl_vec[lbl, :, :] = 1
      img_w_lbl = torch.cat([img, lbl_vec], dim=0)
      imgs_lbls.append((img, lbl, one_hot, img_w_lbl))
    nd_tm = time.perf_counter()
    print(f"\ntime to merge images and labels: {nd_tm - st_tm}")  
    torch.save(imgs_lbls, r"./data/img_and_lbl.pt")
    imgs_lbls = None
  else:
    imgs_cmplt = torch.load(r"./data/img_and_lbl.pt")

The terminal output follows.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap5> python wgan-wp_g_ng.py
Dataset ImageFolder
    Number of datapoints: 4448
    Root location: data\glasses
    StandardTransform
Transform: Compose(
               Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=True)
               ToTensor()
               Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
           )

time to merge images and labels: 149.1113638000097

It really felt like it took much longer than 149 seconds. And, of course, proc_raw is now False.

Once I had the file and was loading it rather than producing the revised dataset, I instantiated the DataLoader.

  # instantiate dataloader using new dataset
  img_ldr = torch.utils.data.DataLoader(
    imgs_cmplt, batch_size=batch_sz, shuffle=True)

Test Loading from File

I decided I should make sure I am getting what I expect from the DataLoader using the dataset loaded from file.

I will print out some info about the shape/size of the full dataset. Next I will obtain a batch of data, assigning each element of the tuple to a separate variable, and print some information regarding the shape/size of those variables. Then I will extract the images from the tensor holding the combined images and labels and plot them, using the label list (2nd element of the batch tuple) to annotate the images.

  def tst_inl_load(imgs_cmplt):
    print(len(imgs_cmplt), len(imgs_cmplt[0]), imgs_cmplt[0][-1].shape)

    # get a batch from the dataloader
    i_img, i_lbl, i_one_hot, i_iwl = next(iter(img_ldr))
    print("n", len(i_img), len(i_lbl), i_one_hot.shape, i_iwl.shape)
    print("\n", i_lbl)
    int_oh = [int(oh) for oh in i_one_hot[:, 1]]
    print("      ", int_oh)

    # Let's see if I can extract images in i_iwl and display
    # convert image data to range 0 - 1
    img = i_iwl[:, :3, :, :] / 2 + 0.5
    img_grid_lbl(img, i_lbl, 4, i_show=True, epoch=0)


  if tst_inl:
    tst_inl_load(imgs_cmplt)

And the terminal output.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap5> python wgan-gp_g_ng.py
4448 4 torch.Size([5, 256, 256])

16 16 torch.Size([16, 2]) torch.Size([16, 5, 256, 256])

tensor([1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1])
       [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1]

The test image follows.

Sample of images produced when testing the dataloader code with the revised dataset loaded from file.

Probably a little too simple of a test. But better than nothing at all and it does seem to show things work as expected.

Done

I feel this post has gotten to be about long enough. And the coding took me, over a period of a few days, a little longer than I expected. So I am going to call it finished. Building the training loop will have to wait for another day.

I found experimenting with the test code somewhat enlightening, both in terms of how things sort of work and what one can actually extract from the various bits of data.

Until next time, may you also find your time coding to be enlightening. Rather than disheartening.

Resources