Well! It has been over two weeks since I completed the draft of the last “Side Trip” post (it is July 30, 2024 as I start this draft). Guess I was more nervous about starting work on a CycleGAN (Cycle-Consistent Generative Adversarial Network) than I thought. It is a touch more convoluted than previous GANs. In fact it is essentially two GANs in one, along with some additional losses that are measured and used to train the generator networks appropriately.
And another delay: after almost completing the draft for this post, I decided to step back, postpone this one, and write the previous post on my approach to using a separate module for project global variables. As someone once pointed out, no plan can survive without change.
CycleGAN
What is a CycleGAN you ask?
The GAN projects to date have all been trying to produce images that are indistinguishable from the images in the training data, e.g. items of clothing, anime faces, people with or without glasses. What a CycleGAN proposes to do is take images from one domain (e.g. horses) and convert them to images in another domain (e.g. zebras). Or a photograph into a watercolour. And, of course, vice versa.
We could do this by having the exact, or almost exact, paired images in the two domains. A CycleGAN is designed to train the generator without paired examples. That is, the images in each domain can be completely different from each other.
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
Our goal is to learn a mapping \(G: X \rightarrow Y\) such that the distribution of images from \(G(X)\) is indistinguishable from the distribution \(Y\) using an adversarial loss. Because this mapping is highly under-constrained, we couple it with an inverse mapping \(F: Y \rightarrow X\) and introduce a cycle consistency loss to push \(F(G(X)) \approx X\) (and vice versa).
Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros
Cycle Consistency
We will code two GANs, one for each domain. Effectively, the two generators of the GANs are the generator of the CycleGAN, and the two discriminators are the discriminator of the CycleGAN. This compound model is then trained to generate images from the source domain that are similar to images from the target domain and vice versa.
The descriptor “cycle” is the key to the process. Images from the source domain are run through the target domain generator. The result is then run through the source domain generator. The initial image and the final image, if everything is working to perfection, should be the same. Ditto for the opposite direction. This is referred to as cycle consistency. An additional loss is added to the training cycle to ensure cycle consistency, in both directions.
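To make the round trip concrete, here is a minimal sketch of a cycle consistency loss. It is not the code for this project: the function name `cycle_consistency_loss` and the names `gen_ab` and `gen_ba` are made up for illustration, and the two tiny convolution layers are only stand-ins for real generator networks. It uses an L1 distance, as the CycleGAN paper does.

```python
import torch
import torch.nn as nn


def cycle_consistency_loss(gen_ab, gen_ba, real_a, real_b):
    """L1 distance between each real batch and its round-trip reconstruction."""
    l1 = nn.L1Loss()
    recon_a = gen_ba(gen_ab(real_a))  # A -> B -> back to A
    recon_b = gen_ab(gen_ba(real_b))  # B -> A -> back to B
    return l1(recon_a, real_a) + l1(recon_b, real_b)


# Stand-in "generators" just to exercise the function (hypothetical, untrained)
torch.manual_seed(0)
gen_ab = nn.Conv2d(3, 3, kernel_size=3, padding=1)
gen_ba = nn.Conv2d(3, 3, kernel_size=3, padding=1)
real_a = torch.rand(2, 3, 8, 8)
real_b = torch.rand(2, 3, 8, 8)

loss = cycle_consistency_loss(gen_ab, gen_ba, real_a, real_b)
# A pair of generators that reproduce their input exactly has zero cycle loss
perfect = cycle_consistency_loss(nn.Identity(), nn.Identity(), real_a, real_b)
print(loss.item(), perfect.item())
```

Note that a zero cycle loss alone does not mean the translation is any good (two identity mappings achieve it), which is why it is combined with the adversarial loss.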
Identity Mapping
CycleGANs have experienced some issues. One is the darkening of colours. Another is over-training. One of the modifications to deal with this is the introduction of Identity Loss.
This loss can regularize the generator to be near an identity mapping when real samples of the target domain are provided. If something already looks like from the target domain, you should not map it into a different image.
Jun-Yan Zhu
In other words, if we present an image of a zebra to the zebra generator we don’t want it to generate a different zebra image.
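A rough sketch of that idea, again with made-up names (`identity_loss`, `gen_to_target`) and stand-in generators rather than real networks: feed a real target-domain batch to the generator that produces that domain and penalize any change with an L1 distance.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def identity_loss(gen_to_target, real_target):
    """L1 penalty for altering an image that is already in the target domain."""
    return F.l1_loss(gen_to_target(real_target), real_target)


zebras = torch.rand(2, 3, 8, 8)  # pretend batch of real zebra images

# An ideal zebra generator leaves real zebras untouched
unchanged = identity_loss(nn.Identity(), zebras)
# A (hypothetical) generator that darkens its input gets penalized
darkened = identity_loss(lambda x: 0.5 * x, zebras)
print(unchanged.item(), darkened.item())
```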
At present I plan to include this loss in my network model. But that may depend on how much effort that will take. Many of the tutorials I looked at didn’t do so. But I expect it provides significant benefit to the training process, though likely at the cost of training time.
Total Loss
The discriminator loss is a standard adversarial loss we have seen in our previous GAN projects.
The generator loss, on the other hand, will involve the adversarial loss, the cycle consistency loss and the identity loss. And it will involve those losses for both cycle directions if appropriate.
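For reference, using the paper's notation from the quote above (\(G: X \rightarrow Y\), \(F: Y \rightarrow X\), with discriminators \(D_X\) and \(D_Y\)), the pieces of the generator loss can be written as follows. The weights \(\lambda_{cyc}\) and \(\lambda_{id}\) are hyperparameters; the symbols for them are my own labels, and I have not yet settled on values.

\[ \mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x}\left[\lVert F(G(x)) - x \rVert_1\right] + \mathbb{E}_{y}\left[\lVert G(F(y)) - y \rVert_1\right] \]

\[ \mathcal{L}_{identity}(G, F) = \mathbb{E}_{y}\left[\lVert G(y) - y \rVert_1\right] + \mathbb{E}_{x}\left[\lVert F(x) - x \rVert_1\right] \]

\[ \mathcal{L}_{gen} = \mathcal{L}_{GAN}(G, D_Y) + \mathcal{L}_{GAN}(F, D_X) + \lambda_{cyc}\,\mathcal{L}_{cyc}(G, F) + \lambda_{id}\,\mathcal{L}_{identity}(G, F) \]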
Pretty brief coverage of a CycleGAN. But if you are interested in the details there are plenty of tutorials on the web.
The Python Modules
My goal is to create something tidier and DRYer than my usual one big module approach. So I expect it is going to take me quite some time to code this project. As well, there are going to be a couple of new features in some of the networks. For example, residual blocks.
I plan to put various related portions of the code in their own modules. E.G. utility functions (plotting and such), model classes, training code, etc. Don’t know how successful I will be, but it is considered a best practice.
Have started (2024.07.31) on some of the code related to loading data and presenting to the networks. That is proving a touch more finicky than I expected. Have a feeling this project is going to take a goodly amount of time to get working. And, as usual, trouble going from tensors to something matplotlib will accept. But did sort that eventually.
Am also thinking this project will involve more than a couple blog posts.
I started by creating cyc_gan.py. This will be my primary module for training the CycleGAN networks. I added a bunch of imports, the command line parameter/global variables related code from the last post, a number of global variables used to control program execution and some others to be used specifically for coding/training/using the model networks.
cyc_gan.py Initial Code
# ../proj6/cyc_gan.py
# Ver 0.1.0: 2024.07.31, rek,
# first attempt at coding a cycle gan, rek, 2024.07.31
# at this point I am not yet sure of the datasets I will use
import math, time
from pathlib import Path
from time import localtime, strftime
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as trf
from torchvision.utils import make_grid
# from config import rng, trfs, trc_seed
import config as cfg
from datasets import LoadData
from utils import image_grid, tensor2image
# runtime control for this module,
# copied from other projects, may not use them all, may add others
tst_clp = False
trn_model = True
trn_prev = False
tst_get_data = False
plot_ds = False
prn_critic = False
tst_critic = False
prn_genr = False
tst_genr = False
proc_raw = False
tst_inl = False
tst_t_btch = False
tst_model = False
tst_mod_2 = False
# will need this later
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
# debugging
# explicitly raise an error with a stack trace
torch.autograd.set_detect_anomaly(True)
# get command line parameters, update globals, create project sub-directories
cl_args = cfg.get_cl_args()
if tst_clp:
    print(cl_args)
    print(f"before updt -> run_nm : {cfg.run_nm}, epochs: {cfg.epochs}")
cfg.updt_cl_args(cl_args)
if tst_clp:
    print(f"after updt -> run_nm : {cfg.run_nm}, epochs: {cfg.epochs}")
cfg.mk_dirs()
# define the classes/functions we will use during training or evaluation
Data
Okay, I downloaded the Horse2zebra Dataset from kaggle (zip).
But, because for each iteration we want to load images from two directories, we can’t simply use torchvision.datasets.ImageFolder as we did in the previous project. We have to write our own class to get the data to provide to the DataLoader class when we instantiate it. A bit of work sorting that.
LoadData Class and datasets.py
So a new module: datasets.py. In the process of coding the LoadData class, I also ended up creating a utils.py module. And of course there is the config.py module created in the last post.
I used an if __name__ == "__main__" block for my module test code. (Should really learn to use a testing package.) The Python style guide says all imports should be at the top of the file. So I used the same if block at the top of the module for the imports only used for testing.
torch.utils.data.Dataset is an abstract class representing a dataset. Your custom dataset should inherit Dataset and override the following methods:
- __len__ so that len(dataset) returns the size of the dataset.
- __getitem__ to support the indexing such that dataset[i] can be used to get \(i\)th sample.
And, we will obviously need an __init__ to instantiate the class with whatever data it needs to execute the two methods above.
Later in my coding process, I got the following error while testing the dataloader:
RuntimeError: output with shape [1, 256, 256] doesn't match the broadcast shape [3, 256, 256]
Looks like there is at least one grey-scale image in the datasets. So, I chained in a method to convert all images to RGB.
# ../proj6/datasets.py
# Ver 0.1.0: 2024.07.31, rek,
# we need to load images from two different directories for each iteration
# so need to create our own class for loading images from disk
import pathlib, random
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as trf
import config as cfg
if __name__ == "__main__":
    # These imports not really needed when module loaded as import,
    # only using for testing.
    import matplotlib.pyplot as plt
    from utils import tensor2image
    from pathlib import Path
class LoadData(Dataset):
    def __init__(self, root_A, root_B, trfs=None, unaligned=False, mode='train'):
        # an empty Compose is a no-op, so the default trfs=None does not fail
        # (mode is not used yet)
        self.transform = trf.Compose(trfs or [])
        self.unaligned = unaligned
        src_A = pathlib.Path(root_A)
        src_B = pathlib.Path(root_B)
        self.files_A = sorted(list(src_A.iterdir()))
        self.files_B = sorted(list(src_B.iterdir()))

    def __getitem__(self, index):
        # modulo wraps the index because the two directories differ in size
        # convert("RGB") handles the odd grey-scale image in the datasets
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]).convert("RGB"))
        if self.unaligned:
            # pick a random B image so the A/B pairing changes every time
            item_B = self.transform(Image.open(self.files_B[cfg.rng.integers(0, len(self.files_B))]).convert("RGB"))
        else:
            item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]).convert("RGB"))
        return item_A, item_B

    def __len__(self):
        # the larger directory sets the length; the smaller one wraps around
        return max(len(self.files_A), len(self.files_B))
if __name__ == "__main__":
    print("Testing LoadData class")
    get_data = LoadData(cfg.ds_A, cfg.ds_B, trfs=cfg.trfs)
    print(f"test LoadData.__len__(): {len(get_data)}")
    print(f"length of files_A (horses): {len(get_data.files_A)}, length of files_B (zebras): {len(get_data.files_B)}")
    print(f"horse[0]: {get_data.files_A[0]}, zebra[0]: {get_data.files_B[0]}")
    print(f"test LoadData.__getitem__()")
    i_data_a, i_data_b = get_data[0]
    # report the shape of both returned items
    print(f"\treturned by __getitem__: {i_data_a.shape} & {i_data_b.shape}")
    # this also effectively tests utils.tensor2image()
    imgplot = plt.imshow(tensor2image(i_data_a))
    plt.show()
    imgplot = plt.imshow(tensor2image(i_data_b))
    plt.show()
And at this point the utilities module is as follows.
# ../proj6/utils.py
# Ver 0.1.0: 2024.07.31, rek,
import numpy as np
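def tensor2image(tensor):
    # move channels last, (C, H, W) -> (H, W, C), and map [-1, 1] back to [0, 1]
    image = tensor.detach().cpu().float().permute(1, 2, 0).numpy() / 2 + 0.5
    # after the permute the channel axis is the last one;
    # repeat a single grey-scale channel three times so matplotlib sees RGB
    if image.shape[2] == 1:
        image = np.tile(image, (1, 1, 3))
    return image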
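For anyone wondering about the `/ 2 + 0.5` and the `permute`, here is a small check of the arithmetic. It assumes the images were normalized with a mean and standard deviation of 0.5 per channel (an assumption, since the transforms in config.py are not shown in this post).

```python
import torch

# Assumption: inputs were normalized with mean=0.5, std=0.5 per channel
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # values as ToTensor would give them
normalized = (pixels - 0.5) / 0.5             # what that Normalize computes, range [-1, 1]
restored = normalized / 2 + 0.5               # what tensor2image applies, back to [0, 1]
print(normalized, restored)

# matplotlib wants (height, width, channels); PyTorch uses (channels, height, width)
chw = torch.zeros(3, 4, 6)
hwc = chw.permute(1, 2, 0).numpy()
print(hwc.shape)
```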
And here’s the terminal output. I won’t bother displaying the images. Will do that a little further on.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj6> python datasets.py
Testing LoadData class
test LoadData.__len__(): 1334
length of files_A (horses): 1067, length of files_B (zebras): 1334
horse[0]: data\trainA\n02381460_1001.jpg, zebra[0]: data\trainB\n02391049_10007.jpg
test LoadData.__getitem__()
returned by __getitem__: torch.Size([3, 256, 256]) & torch.Size([3, 256, 256])
DataLoader
Okay, let’s use that new LoadData class to set up a DataLoader, in cyc_gan.py, to provide batches of images of both types to use for training our networks. Then print a batch of images to make sure things work as expected.
... ...
if trn_model:
    # set seed for repeatability
    torch.manual_seed(cfg.pt_seed)
    # let's define the classes/functions we will only use during training
    # let's get our data set up
    get_data = LoadData(cfg.ds_A, cfg.ds_B, trfs=cfg.trfs)
    d_ldr = DataLoader(
        get_data,
        batch_size=cfg.batch_sz,
        # num_workers=0,
        shuffle=True,
    )
    if tst_get_data:
        # each batch is a pair: [images from A, images from B]
        a_btch = next(iter(d_ldr))
        print(len(a_btch))
        print(a_btch[0].shape, a_btch[1].shape)
        image_grid(a_btch[0], 4, i_show=True, epoch=0, b_sz=cfg.batch_sz, img_cl='A')
        image_grid(a_btch[1], 4, i_show=True, epoch=0, b_sz=cfg.batch_sz, img_cl='B')
And, with tst_get_data equal to True, the terminal output and images were as follows.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj6> python cyc_gan.py
2
torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
Other Refactored Modules
I refactored the function image_grid from an earlier version. Here’s that new code for the modified utils.py module. I will let you sort the imports. Note the use of a global from the config module. May eventually need to alter the final name construction.
# ../proj6/utils.py
... ...
def image_grid(images, ncol, i_show=True, epoch=0, b_sz=16, img_cl='A'):
    image_grid = make_grid(images, ncol, normalize=True)  # Make images into a grid
    image_grid = image_grid.permute(1, 2, 0)  # Move channel to the last dim
    image_grid = image_grid.cpu().numpy()  # Convert to Numpy
    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    if i_show:
        plt.show()
    else:
        plt.savefig(cfg.img_dir / f"{img_cl}_{epoch}_{b_sz}.png")
    plt.close()
... ...
Done
As things have changed a bit over the last while, I think it would be a good idea to bring this post to an end.
Next post I hope to move on to coding the discriminator and generator networks. Remember each one will be instantiated twice, so class definitions are in order. As is, at least, one new module. Though might look at a separate module for each network type.
Until next time, stay focused and keep things simple.
Resources
- Datasets & DataLoaders
- Horse2zebra Dataset
- Understanding transform.Normalize