Okay, time to look at building and training the originally intended multi-category classification model.

Initial Setup Code

I am going to start a new Python module, but the imports, initial variables and data loading code will be the same as in the binary classification module. There will, however, be a number of differences further along.

# multi-cat.py
#  - train multi-category classification model for Fashion-MNIST dataset
# Ver 0.1.0: 2024.03.18, rek, get started figuring this out

import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as trf

# set seed for reproducibility
torch.manual_seed(73)

# will need this later
device = "cuda" if torch.cuda.is_available() else "cpu"

# instantiate our data transform for the image dataset
# generate tensors (vals 0-1), then centre and normalize them
transform = trf.Compose([
  trf.ToTensor(),
  trf.Normalize((0.5,), (0.5,))
])

# Create datasets for training & testing, download if necessary
ds_train = torchvision.datasets.FashionMNIST(
  root="./data",
  train=True,
  download=True,
  transform=transform
)
ds_test = torchvision.datasets.FashionMNIST(
  root="./data",
  train=False,
  download=True,
  transform=transform
)

# Class labels, in dataset they are digits 0-9
categories = ('t-shirt/top', 'trouser', 'pullover', 'dress', 'coat',
              'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
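
Before moving on, a quick sanity check (throwaway code, not part of the module): ToTensor scales the pixel values into the range 0-1, and Normalize((0.5,), (0.5,)) then maps them to the range -1 to 1 via (x - 0.5) / 0.5.

# quick check of the transform's output (throwaway code)
img, lbl = ds_train[0]
print(img.shape)             # torch.Size([1, 28, 28])
print(img.min(), img.max())  # approx tensor(-1.) tensor(1.)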

Validation Set

One thing that will be different this time is that we will be using a validation dataset. It will be used to determine the ‘optimal’ number of epochs for training, though it could also be used to help tune any number of other hyperparameters. After each epoch, the current model state will be evaluated against the validation set. Once the validation loss has not improved for some specified number of epochs, training will be stopped.

So we will need to split the original training set into a training set and a validation set; I will use an 80-20 split. We will then need to generate three data loaders: one each for the training, validation and test datasets.

# create validation dataset
ds_train, ds_val = torch.utils.data.random_split(ds_train, [0.8, 0.2])
if True:
  # Show split sizes
  print(f"Training set has {len(ds_train)} instances")
  print(f"Validation set has {len(ds_val)} instances")
(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py
Training set has 48000 instances
Validation set has 12000 instances

Data Loaders

Not much to say.

# Create data loaders for our datasets (shuffling really only matters for training)
batch_sz = 64
train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_sz, shuffle=True)
val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_sz, shuffle=True)
test_loader = torch.utils.data.DataLoader(ds_test, batch_size=batch_sz, shuffle=True)
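
A quick look at one batch (again, throwaway code) shows the shapes we will be working with, and why the training code later flattens each image into a 784 element vector.

# peek at a single batch (throwaway code)
imgs, lbls = next(iter(train_loader))
print(imgs.shape)   # torch.Size([64, 1, 28, 28])
print(lbls.shape)   # torch.Size([64])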

Learning Rate, Optimizer, Loss Function

The learning rate and optimizer will remain the same as those used for the binary classification model.

However, we can no longer use binary cross-entropy loss. For multi-category classification, categorical cross-entropy (nn.CrossEntropyLoss in PyTorch) is the recommended loss function.

# set learning rate, optimizer and loss function
lr = 0.001
optimizer = torch.optim.Adam(mc_model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
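
Worth knowing: nn.CrossEntropyLoss takes the model's raw output scores along with integer class labels, no one-hot encoding required. A tiny throwaway example:

# throwaway demo of CrossEntropyLoss
scores = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 3.0, 0.3]])   # 2 samples, 3 classes
targets = torch.tensor([0, 1])             # correct class indices
print(nn.CrossEntropyLoss()(scores, targets))  # smallish loss, the scores favour the right classes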

Early Stop

To make things tidier I am going to encapsulate the variables and the stop-evaluation function in a class. Its only purpose is to keep track of the minimum loss of the model against the validation set, resetting that value as necessary. It will also keep track of the number of epochs since that minimum was last set. If that count reaches some specified limit, the class method will return True; otherwise it will return False.

The value passed to the class method will be generated by another function that evaluates the current model against the validation set.

# early stop class
class EarlyStop:
  def __init__(self, e_wait=8):
    self.e_wait = e_wait   # stop if notbttr >= e_wait
    self.notbttr = 0    # number of epochs without improvement
    # last best loss, start with big number as looking for min
    self.min_loss = float('inf')
  
  # vs_loss: current value of validation set average loss
  def is_stop(self, vs_loss):
    if vs_loss < self.min_loss:
      self.min_loss = vs_loss
      self.notbttr = 0
    else:
      self.notbttr += 1
    return self.notbttr >= self.e_wait


# some model parameters, hyperparameters?
max_ep = 100    # maximum number of epochs for training
do_p = 0.16     # Dropout probability
e_stop = 8      # number of epochs for early stop

# instantiate class
chk_stop = EarlyStop(e_stop)
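
To convince myself the class behaves as intended, a throwaway test with a fake sequence of validation losses and a patience of 2:

# throwaway test of EarlyStop
es = EarlyStop(e_wait=2)
for vl in [0.9, 0.8, 0.85, 0.83]:
  print(vl, es.is_stop(vl))
# is_stop returns False, False, False, True: 0.83 is the 2nd epoch without a new minimum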

Model

We will need to change the final output activation function. Sigmoid is only valid for binary classification; for multi-category classification, softmax is the usual choice. It squeezes the outputs into the range 0-1 and, within the limits of floating point arithmetic, our 10 output values will sum to 1.0.

They are effectively the probabilities that each category matches the image being assessed, which is why they need to sum to 1.0. (One caveat I have since come across: nn.CrossEntropyLoss applies log-softmax internally and expects the raw scores, so an explicit Softmax layer in front of it is redundant at best. I am leaving it in for now, but it is something to revisit.)

And since we will now have 10 outputs, I am going to increase the number of inputs to the final layer to 64. In general, for neural networks we want to change the number of neurons gradually from one layer to the next. Since the output has gone from 1 neuron to 10, it is likely best to increase the number of input neurons.

# define model
mc_model = nn.Sequential(
  nn.Linear(28*28, 256),
  nn.ReLU(),
  nn.Linear(256, 128),
  nn.ReLU(),
  nn.Linear(128, 64),
  nn.ReLU(),
  nn.Linear(64, 10),
  nn.Dropout(p=do_p),
  nn.Softmax(dim=1)
).to(device)
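
And to convince myself the outputs really do behave like probabilities, another throwaway check that each row of a softmaxed tensor sums to 1:

# throwaway softmax check
raw = torch.randn(2, 10)
probs = nn.Softmax(dim=1)(raw)
print(probs.sum(dim=1))   # tensor([1.0000, 1.0000]), give or take float error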

Run Epoch Function

Rather than include the epoch training code directly in our model training loop, we'll put it in a separate function. It's tidier and considered better programming practice. Pretty similar to the code used for the binary classification model in the last post.

# function to execute single training epoch on current model state
def ex_epoch():
  mc_model.train()    # make sure dropout, etc. are active
  tot_loss = 0
  for imgs, lbls in train_loader:
    # flatten image tensor, send to gpu
    imgs = imgs.reshape(-1, 28*28).to(device)
    # labels to gpu as well
    lbls = lbls.to(device)
    preds = mc_model(imgs)
    loss = loss_fn(preds, lbls)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    tot_loss += loss.item()   # item() so we don't hang on to the graph
  # average loss over all batches in the training loader
  return tot_loss / len(train_loader)

Validation Function

This function runs the current model over the validation set, applies our loss function to each batch of predictions, and returns the average loss. Pretty much a repeat of the above, excluding the training-specific steps.

# function to calculate average loss of current model against validation set
def get_val_loss():
  v_loss = 0
  with torch.no_grad():   # no gradients needed for evaluation
    for imgs, lbls in val_loader:
      # flatten image tensor, send to gpu
      imgs = imgs.reshape(-1, 28*28).to(device)
      # ditto labels
      lbls = lbls.to(device)
      preds = mc_model(imgs)
      loss = loss_fn(preds, lbls)
      v_loss += loss.item()
  # average loss over all batches in the validation loader
  return v_loss / len(val_loader)

Test Model

To test the model we need to make sure it is in evaluation mode rather than training mode. If I were planning to train it further, I am led to believe I would have to call mc_model.train() before doing any further training. I have not yet tested that option.

And, we want to make sure we are using the GPU to test the model.

# test the model
mc_model.eval()
mc_model.to(device)
st_tm = time.perf_counter()
results = []
with torch.no_grad():
  for imgs, lbls in test_loader:
    imgs = imgs.reshape(-1, 28*28).to(device)
    lbls = lbls.to(device)
    preds = mc_model(imgs)
    # get index of category with max probability
    pred_l = torch.argmax(preds, dim=1)
    correct = (pred_l == lbls)
    results.append(correct.cpu().numpy().mean())

accuracy = np.array(results).mean()
nd_tm = time.perf_counter()
print(f"\nprediction accuracy is {accuracy}")
print(f"\ntime to test model: {nd_tm - st_tm}")

Train Model

The testing code is run immediately after the training of the model is stopped.
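
For reference, the loop itself looks roughly like the following sketch; for the first run below the early stop check was disabled.

# train the model, stopping early if validation loss stalls
st_tm = time.perf_counter()
for i in range(max_ep):
  trn_loss = ex_epoch()
  val_loss = get_val_loss()
  print(f"epoch {i + 1}: training loss = {trn_loss}, validation loss = {val_loss}")
  if chk_stop.is_stop(val_loss):
    break
nd_tm = time.perf_counter()
print(f"\ntime to train model: {nd_tm - st_tm}")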

No Early Stop & Dropout(p=0.25)

(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py
epoch 1: training loss = 1.8184418678283691, validation loss = 1.7703135013580322
epoch 2: training loss = 1.7660441398620605, validation loss = 1.7609056234359741
... ...
epoch 99: training loss = 1.765865445137024, validation loss = 1.7917983531951904
epoch 100: training loss = 1.758225679397583, validation loss = 1.7615996599197388

time to train model: 1001.4624395999999

prediction accuracy is 0.695859872611465

time to test model: 1.6990639000141528

Dropout(p=0.25) & EarlyStop(8)

(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py
epoch 1: training loss = 1.8184418678283691, validation loss = 1.7703135013580322
epoch 2: training loss = 1.7660441398620605, validation loss = 1.7609056234359741
... ...
epoch 11: training loss = 1.738189458847046, validation loss = 1.7366585731506348
epoch 12: training loss = 1.737826943397522, validation loss = 1.7418655157089233
epoch 13: training loss = 1.7358711957931519, validation loss = 1.7380039691925049
epoch 14: training loss = 1.734538197517395, validation loss = 1.7506725788116455
epoch 15: training loss = 1.7397111654281616, validation loss = 1.7422360181808472
epoch 16: training loss = 1.7360037565231323, validation loss = 1.7536872625350952
epoch 17: training loss = 1.7319772243499756, validation loss = 1.7437314987182617
epoch 18: training loss = 1.7376022338867188, validation loss = 1.744455099105835
epoch 19: training loss = 1.7387878894805908, validation loss = 1.7498186826705933

time to train model: 185.28267669997877

prediction accuracy is 0.6934713375796179

time to test model: 1.6849242999742273

Dropout(p=0.25) & EarlyStop(12)

Not particularly good performance. Let’s up the wait cycles to 12 from the default 8.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py
epoch 1: training loss = 1.8184418678283691, validation loss = 1.7703135013580322
epoch 2: training loss = 1.7660441398620605, validation loss = 1.7609056234359741
... ...
epoch 11: training loss = 1.738189458847046, validation loss = 1.7366585731506348
epoch 12: training loss = 1.737826943397522, validation loss = 1.7418655157089233
epoch 13: training loss = 1.7358711957931519, validation loss = 1.7380039691925049
epoch 14: training loss = 1.734538197517395, validation loss = 1.7506725788116455
epoch 15: training loss = 1.7397111654281616, validation loss = 1.7422360181808472
epoch 16: training loss = 1.7360037565231323, validation loss = 1.7536872625350952
epoch 17: training loss = 1.7319772243499756, validation loss = 1.7437314987182617
epoch 18: training loss = 1.7376022338867188, validation loss = 1.744455099105835
epoch 19: training loss = 1.7387878894805908, validation loss = 1.7498186826705933
epoch 20: training loss = 1.7342194318771362, validation loss = 1.7504445314407349
epoch 21: training loss = 1.73516047000885, validation loss = 1.7515453100204468
epoch 22: training loss = 1.7338981628417969, validation loss = 1.820580005645752
epoch 23: training loss = 1.731840968132019, validation loss = 1.7761754989624023

time to train model: 224.5111149000004

prediction accuracy is 0.6847133757961783

time to test model: 1.6120651999954134

Dropout(p=0.20) & EarlyStop(12)

Even worse. Let's see if we can improve things with a lower probability for Dropout.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py
epoch 1: training loss = 1.7902714014053345, validation loss = 1.7459933757781982
epoch 2: training loss = 1.7353540658950806, validation loss = 1.7412049770355225
epoch 3: training loss = 1.7215162515640259, validation loss = 1.727048397064209
... ...
epoch 27: training loss = 1.705673098564148, validation loss = 1.6968224048614502
epoch 28: training loss = 1.7004787921905518, validation loss = 1.6970628499984741
epoch 29: training loss = 1.7018179893493652, validation loss = 1.7099974155426025
epoch 30: training loss = 1.7047582864761353, validation loss = 1.718260407447815
epoch 31: training loss = 1.7127240896224976, validation loss = 1.7111423015594482
epoch 32: training loss = 1.7021905183792114, validation loss = 1.7154974937438965
epoch 33: training loss = 1.703952670097351, validation loss = 1.718424916267395
epoch 34: training loss = 1.7015137672424316, validation loss = 1.704266905784607
epoch 35: training loss = 1.694817066192627, validation loss = 1.7009721994400024
epoch 36: training loss = 1.6981953382492065, validation loss = 1.7077285051345825
epoch 37: training loss = 1.703160285949707, validation loss = 1.7209525108337402
epoch 38: training loss = 1.6995770931243896, validation loss = 1.7061223983764648
epoch 39: training loss = 1.710923194885254, validation loss = 1.7131465673446655

time to train model: 374.70451449998654

prediction accuracy is 0.734375

time to test model: 1.6185605000064243

A slight improvement.

Dropout(p=0.16) & EarlyStop(9)

Okay, reduce dropout probability a little more, and set early stop at 9 epochs without loss improvement on validation set.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py
epoch 1: training loss = 1.7686419486999512, validation loss = 1.704501748085022
epoch 2: training loss = 1.711744785308838, validation loss = 1.6997143030166626
epoch 3: training loss = 1.7013624906539917, validation loss = 1.693571925163269
... ...
epoch 19: training loss = 1.674493432044983, validation loss = 1.6730879545211792
epoch 20: training loss = 1.6722667217254639, validation loss = 1.6846728324890137
epoch 21: training loss = 1.6740472316741943, validation loss = 1.680519938468933
epoch 22: training loss = 1.672454595565796, validation loss = 1.6903433799743652
epoch 23: training loss = 1.6759159564971924, validation loss = 1.6896322965621948
epoch 24: training loss = 1.6699936389923096, validation loss = 1.6915292739868164
epoch 25: training loss = 1.6795790195465088, validation loss = 1.7103519439697266
epoch 26: training loss = 1.6789908409118652, validation loss = 1.6974619626998901
epoch 27: training loss = 1.6792508363723755, validation loss = 1.6863120794296265
epoch 28: training loss = 1.683669090270996, validation loss = 1.6815105676651

time to train model: 278.7310662000091

prediction accuracy is 0.8336982484076433

time to test model: 1.6716616000048816

That seems to have improved things.

And, in the end, we used the test set accuracy to help choose better values for a couple of hyperparameters. Unfortunately, I have since done some reading that implies that is not a good idea, and that for that purpose we should be using the validation dataset, keeping the test set strictly for the final assessment. Will need to do more research.

Save and Load Model

Okay, what if I want to save the model for future use without training it every time?

Save Model to File

This took me a while to sort out and I am still nowhere near sure I am doing it the right way. But here's my code for saving and loading the model. I didn't want, at this point, to rework the model into a class so that I could also save that. So, I have to make sure the model is defined before I load the saved state and set it up for further predictions.

At the end of the file, I have the code to save the model in an if block.

... ...
# model from/to file
ld_model = False
sv_model = True
... ...
# save model for later reload, rather than training again
if sv_model:
  checkpt = {
    'state_dict': mc_model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 29,
    'trn_loss': trn_loss
  }
  torch.save(checkpt, 'mc_checkpt.pth')

Load Model from File

And I have put the training code into an if block as well. The else branch loads the model from the saved file. I am using not ld_model to mean ‘train the model’.

if not ld_model:
  for i in range(max_ep):
    trn_loss = ex_epoch()
    val_loss = get_val_loss()
  ... ...
else:
  # load parameters from file
  # model must already be defined/instantiated above
  chk_pt = torch.load('mc_checkpt.pth')
  mc_model.load_state_dict(chk_pt['state_dict'])
  optimizer.load_state_dict(chk_pt['optimizer'])
  trn_loss = chk_pt['trn_loss']

Use Loaded Model Against Test Set

Loading the model from file and running it against the test set produced the following output.

(mclp-3.12) PS F:\learn\mcl_pytorch\chap2> python multi-cat.py

prediction accuracy is 0.8331011146496815

time to test model: 2.05665849999059

Fini

That’s it for this post. Lots of learning, lots of research and reading. Maybe a little better understanding of what’s going on.

Not sure what I’ll be up to in the next post. But, I might look at putting my model into a class, so that I can look at saving that to file along with all the parameters.

Until next time, may your coding and/or research prove fruitful.
