Okay, let’s move on to getting the training loop coded and tested.
Training
I started out with just the basic loop and a quick test of 2 epochs. At the end of each epoch, I had it save image grids from the last batch: both the training images and the regenerated ones.
Updated Imports and Module Variables
First, I expanded the imports and added some control variables to the main autoencoder module, autoe.py. So here’s everything up to the code that gets and applies any command line args.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from tqdm import tqdm
import math
import sys
# print(sys.path)
sys.path.append('../shared_mods')
# print(sys.path)
import config as cfg # type: ignore
import utils as utl # type: ignore
from logger import Logger # type: ignore
from autoencoder import AE_Simple # type: ignore
# debugging
# explicitly raise an error with a stack trace
torch.autograd.set_detect_anomaly(True)
# runtime control for this module
DEBUG = False
trn_model = True
tst_model = False
# test control booleans
tst_clp = False
tst_get_data = False
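For reference, the command line handling I am not showing just overrides a few of the config module’s settings. The real parsing lives elsewhere, so treat the following as a hypothetical sketch; the flags and setting names are inferred from the run output further down.

import argparse

def get_cli_args():
    # hypothetical stand-in for the real arg parsing;
    # flags match the ones used in the runs below (-rn, -bs, -ep)
    p = argparse.ArgumentParser(description="simple linear autoencoder")
    p.add_argument("-rn", "--run_nm", help="run name, used to build output directories")
    p.add_argument("-bs", "--batch_sz", type=int, help="training batch size")
    p.add_argument("-ep", "--epochs", type=int, help="number of training epochs")
    return p.parse_args()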
Training Loop
As you perhaps noticed, the training loop will be in an if block. Once I got the simple test working, I added code to save checkpoints and the list of loss values. But here’s the first bit.
if trn_model:
    # instantiate autoencoder
    img_wd, img_ht = trn_set[0][0].shape[1], trn_set[0][0].shape[2]
    sz_inp = img_wd * img_ht
    isz_log2 = int(math.log2(sz_inp))
    ls_mid = int(2**(isz_log2 - 1))
    ls_fin = 16
    aec = AE_Simple(sz_inp, ls_mid, ls_fin).to(cfg.device)
    print(f"\n{aec}")
    f_loss = nn.MSELoss()
    opt = optim.Adam(aec.parameters(), lr=0.0005)
    outputs = []
    losses = []
    for epoch in range(cfg.epochs):
        # collect latent vectors and labels for the planned latent space plotting
        l_spc, l_lbl = [], []
        for (img, lbl) in tqdm(trn_ldr, desc=f"epoch {epoch + 1}"):
            # flatten the image tensors to vectors of length sz_inp
            img = img.reshape(-1, sz_inp).to(cfg.device)
            # get autoencoder outputs
            regen, lspc = aec(img)
            # calculate loss and update the model
            loss = f_loss(regen, img)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # store the losses in a list for plotting
            t_loss = loss.detach().cpu().item()
            losses.append(t_loss)
            # detach so the stored latents don't keep the autograd graph alive
            l_spc.extend(lspc.detach().cpu())
            l_lbl.extend(lbl)
        # save plots of 16 images from training data and regenerated data for last iteration
        img_rs = img.reshape(cfg.batch_sz, 1, 28, 28)
        rgn_rs = regen.reshape(cfg.batch_sz, 1, 28, 28)
        s_ndx = int(cfg.rng.integers(0, 16, 1)[0])
        e_ndx = s_ndx + 16
        utl.image_grid(img_rs[s_ndx:e_ndx], 8, i_show=False, epoch=epoch, b_sz=cfg.batch_sz, img_cl='trn')
        utl.image_grid(rgn_rs[s_ndx:e_ndx], 8, i_show=False, epoch=epoch, b_sz=cfg.batch_sz, img_cl='regen')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    # plot the last 100 loss values
    plt.plot(losses[-100:])
    plt.show()
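A quick note on the sizing code at the top of the block: for 28×28 images, sz_inp is 784 and int(math.log2(784)) is 9, so ls_mid works out to 2**8 = 256, matching the model printout below. ls_fin then fixes the latent space at 16 values per image, roughly a 49:1 compression of the input.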
And the command line output.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj7> python autoe.py -rn rek_1 -bs 32 -ep 2
{'run_nm': 'rek_1', 'dataset_nm': 'no_nm', 'sv_img_cyc': 150, 'sv_chk_cyc': 50, 'resume': False, 'start_ep': 0, 'epochs': 5, 'batch_sz': 32, 'num_res_blks': 9, 'x_disc': 1, 'x_genr': 1, 'x_eps': 0, 'use_lrs': False, 'lrs_unit': 'batch', 'lrs_eps': 5, 'lrs_init': 0.01, 'lrs_steps': 25, 'lrs_wmup': 0}
image and checkpoint directories created: runs\rek_1_img & runs\rek_1_sv
AE_Simple(
  (e_init): Linear(in_features=784, out_features=256, bias=True)
  (encoded): Linear(in_features=256, out_features=16, bias=True)
  (d_init): Linear(in_features=16, out_features=256, bias=True)
  (decode): Linear(in_features=256, out_features=784, bias=True)
)
epoch 1: 100%|████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 104.04it/s]
epoch 2: 100%|█████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 95.34it/s]
And here are the images saved after the last (second) epoch of training.
Pretty amazing for only two epochs of training. I won’t bother including a plot of the losses; it isn’t particularly meaningful.
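By the way, utl.image_grid lives in my shared utils module, so it hasn’t appeared in these listings. Here is a minimal sketch of what such a helper does, assuming torchvision’s grid utilities and a cfg.img_dir path for the run’s image directory (both the name and the options are my assumptions; the real helper differs in its details):

import matplotlib.pyplot as plt
import torchvision.utils as vutils

def image_grid(imgs, n_row, i_show=False, epoch=0, b_sz=32, img_cl='img'):
    # arrange the batch into a single grid image and save it to the run's image dir
    grid = vutils.make_grid(imgs.detach().cpu(), nrow=n_row, normalize=True)
    vutils.save_image(grid, cfg.img_dir / f"{img_cl}_ep{epoch}_bs{b_sz}.png")
    if i_show:
        # optionally display the grid; imshow wants channels last
        plt.imshow(grid.permute(1, 2, 0))
        plt.show()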
Add Code to Save Checkpoints and Losses
I wanted to test the model, so I added the code to save model checkpoints and loss data. But that required a small refactor of utils.sv_chkpt(). I wanted to save the JIT script version of the model along with the full checkpoint of the model states. I had commented out that code for the previous project, as I couldn’t save the discriminators that way (due to how I coded them). So: a new parameter, do_jit=False, and the JIT save code moved into a suitable if block.
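I am not showing the refactored utils.sv_chkpt() itself, but the shape of it is something like the sketch below. The signature matches the call further down; the checkpoint dictionary keys and the jitscript filename pattern (inferred from the file loaded in the evaluation section) are assumptions on my part.

def sv_chkpt(run_nm, epoch, model, opt, lrs, b_sz, fl_pth, do_jit=False):
    # full checkpoint: model/optimizer (and scheduler, if any) states plus metadata
    chkpt = {
        "run_nm": run_nm,
        "epoch": epoch,
        "batch_sz": b_sz,
        "model_state": model.state_dict(),
        "opt_state": opt.state_dict(),
    }
    if lrs is not None:
        chkpt["lrs_state"] = lrs.state_dict()
    torch.save(chkpt, fl_pth)
    if do_jit:
        # new: optionally save a TorchScript version of the model as well
        scripted = torch.jit.script(model)
        scripted.save(str(fl_pth.parent / f"{model.__class__.__name__}_jitscript_{b_sz}_{epoch}.pt"))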
After running a model evaluation test, I added the logger code for saving and plotting losses. That said, I am going to show it now with the other changes mentioned above.
... ...
if trn_model:
    # set up logger for losses
    loss_nms = ["aec"]
    lgr_loss = Logger(cfg.run_nm, cfg.sv_chk_cyc, loss_nms)
    print(cfg.device)
    # set seed for repeatability
    ... ...
        utl.image_grid(rgn_rs[s_ndx:e_ndx], 8, i_show=False, epoch=epoch, b_sz=cfg.batch_sz, img_cl='regen')
        utl.sv_chkpt(cfg.run_nm, epoch, aec, opt, None, cfg.batch_sz, cfg.sv_dir/f"AE_Simple_{epoch}.pt", do_jit=True)
    # training epochs complete, save last batch of losses
    all_losses = {"aec": losses}
    lgr_loss.log_losses(all_losses)
    # save losses to file and show plot
    lgr_loss.to_file(cfg.sv_dir, epoch, 1875*cfg.epochs)
    lgr_loss.plot_losses()
... ...
# removed the earlier loss plotting code.
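The Logger class also lives in the shared mods, so for anyone following along without it, a minimal stand-in that satisfies the calls above might look like this (a sketch; the real class does more, e.g. periodic saving tied to sv_chk_cyc):

import json
import matplotlib.pyplot as plt

class Logger:
    def __init__(self, run_nm, sv_cyc, loss_nms):
        self.run_nm = run_nm
        self.sv_cyc = sv_cyc
        # one list of loss values per named model/loss
        self.losses = {nm: [] for nm in loss_nms}

    def log_losses(self, all_losses):
        # append each named list of loss values to its series
        for nm, vals in all_losses.items():
            self.losses[nm].extend(vals)

    def to_file(self, sv_dir, epoch, n_iter):
        # save the accumulated losses as JSON
        with open(sv_dir / f"{self.run_nm}_losses_{epoch}_{n_iter}.json", "w") as fl:
            json.dump(self.losses, fl)

    def plot_losses(self):
        for nm, vals in self.losses.items():
            plt.plot(vals, label=nm)
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()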
One thing I haven’t mentioned is how little CPU and GPU utilization this simple model requires. The fans don’t even spin up when training the model.
Model Evaluation
Okay, I wanted to see how the model, trained for two epochs, would handle some images from the test dataset. So I re-ran the training loop after adding the code to save checkpoints and such, then loaded the saved model and ran a batch of test images through it. Another if block.
if tst_model:
    # instantiate model from the saved jitscript file
    fl_pth = cfg.sv_dir/"AE_Simple_jitscript_32_1.pt"
    aec = torch.jit.load(fl_pth).to(cfg.device)
    aec.eval()
    print("\n", aec)
    # get a batch of images from the test set, the 5th batch in fact
    dl_iter = iter(tst_ldr)
    for _ in range(5):
        tst_img, _ = next(dl_iter)
    tst_img = tst_img.reshape(-1, 28*28).to(cfg.device)
    # get output of autoencoder, no gradients needed for evaluation
    with torch.no_grad():
        rgn_img, ls = aec(tst_img)
    img_rs = tst_img.reshape(cfg.batch_sz, 1, 28, 28)
    rgn_rs = rgn_img.reshape(cfg.batch_sz, 1, 28, 28)
    utl.image_grid(img_rs, 8, i_show=False, epoch=1, b_sz=cfg.batch_sz, img_cl='tst_img')
    utl.image_grid(rgn_rs, 8, i_show=False, epoch=1, b_sz=cfg.batch_sz, img_cl='tst_rgn')
Here’s the terminal output.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj7> python autoe.py -rn rek_1 -bs 32
{'run_nm': 'rek_1', 'dataset_nm': 'no_nm', 'sv_img_cyc': 150, 'sv_chk_cyc': 50, 'resume': False, 'start_ep': 0, 'epochs': 5, 'batch_sz': 32, 'num_res_blks': 9, 'x_disc': 1, 'x_genr': 1, 'x_eps': 0, 'use_lrs': False, 'lrs_unit': 'batch', 'lrs_eps': 5, 'lrs_init': 0.01, 'lrs_steps': 25, 'lrs_wmup': 0}
image and checkpoint directories created: runs\rek_1_img & runs\rek_1_sv
RecursiveScriptModule(
  original_name=AE_Simple
  (e_init): RecursiveScriptModule(original_name=Linear)
  (encoded): RecursiveScriptModule(original_name=Linear)
  (d_init): RecursiveScriptModule(original_name=Linear)
  (decode): RecursiveScriptModule(original_name=Linear)
)
And the images.
Once again, not too shabby.
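A design note on why I bothered with the JIT version: torch.jit.load rebuilds the model without needing the AE_Simple class definition in scope. Loading from the full checkpoint instead would go something like this (a sketch, assuming the checkpoint stores states under the keys from my sv_chkpt sketch above):

# load states from the full checkpoint; requires the class definition
chkpt = torch.load(cfg.sv_dir/"AE_Simple_1.pt", map_location=cfg.device)
aec = AE_Simple(784, 256, 16).to(cfg.device)
aec.load_state_dict(chkpt["model_state"])
aec.eval()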
Training Run for 5 Epochs
I figured I should train for a few more epochs. I was originally thinking of 10, but since the model seemed to be training so well, I decided to just do a five epoch run.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj7> python autoe.py -rn rek_2 -bs 32 -ep 5
{'run_nm': 'rek_2', 'dataset_nm': 'no_nm', 'sv_img_cyc': 150, 'sv_chk_cyc': 50, 'resume': False, 'start_ep': 0, 'epochs': 5, 'batch_sz': 32, 'num_res_blks': 9, 'x_disc': 1, 'x_genr': 1, 'x_eps': 0, 'use_lrs': False, 'lrs_unit': 'batch', 'lrs_eps': 5, 'lrs_init': 0.01, 'lrs_steps': 25, 'lrs_wmup': 0}
image and checkpoint directories created: runs\rek_2_img & runs\rek_2_sv
AE_Simple(
  (e_init): Linear(in_features=784, out_features=256, bias=True)
  (encoded): Linear(in_features=256, out_features=16, bias=True)
  (d_init): Linear(in_features=16, out_features=256, bias=True)
  (decode): Linear(in_features=256, out_features=784, bias=True)
)
epoch 1: 100%|████████████████████████████████████████████████████████████████████| 1875/1875 [00:18<00:00, 104.16it/s]
epoch 2: 100%|█████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 96.51it/s]
epoch 3: 100%|█████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 96.57it/s]
epoch 4: 100%|█████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 95.93it/s]
epoch 5: 100%|█████████████████████████████████████████████████████████████████████| 1875/1875 [00:19<00:00, 96.19it/s]
The regenerated images are really quite close to the ones from the training set, so I won’t bother with another model evaluation test.
But here’s the plot of the losses over all those iterations.
Wow, 0 to 100 mph in no time, then almost forever to get to 125 mph or so. In other words, the loss drops rapidly over the first stretch of iterations, then improves only slowly for the remainder of the run.
Fini!
I think that’s it for this one. The plan for next time is to generate a 2D plot of the latent space for the model when run against the test dataset. Apparently the technique I am going to try can be very compute and memory intensive. Time will tell.
Until then, I hope you found coding a simple autoencoder rather satisfying.