6.2lab_predicting _MNIST_using_Softmax_v2

Apr 1, 2026 | 4 min read | By Mohammed Vasim
AI, Machine Learning, LLM, PyTorch, TensorFlow, Generative AI, LangChain, AI Agents

Softmax Classifier

Objective

  • How to classify handwritten digits from the MNIST database by using a Softmax classifier.

In this lab, you will use a single-layer Softmax classifier to classify handwritten digits from the MNIST database.

Estimated Time Needed: 25 min


Preparation

We'll need the following libraries:

python
# Import the libraries we need for this lab

# Alternatively, use the following line to install the torchvision library with mamba
# !mamba install -y torchvision

!pip install torchvision==0.9.1 torch==1.8.1 
import torch 
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import matplotlib.pylab as plt
import numpy as np

Use the following function to plot out the parameters of the Softmax function:

python
# The function to plot parameters

def PlotParameters(model): 
    W = model.state_dict()['linear.weight'].data
    w_min = W.min().item()
    w_max = W.max().item()
    fig, axes = plt.subplots(2, 5)
    fig.subplots_adjust(hspace=0.01, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        if i < 10:
            
            # Set the label for the sub-plot.
            ax.set_xlabel("class: {0}".format(i))

            # Plot the image.
            ax.imshow(W[i, :].view(28, 28), vmin=w_min, vmax=w_max, cmap='seismic')

            ax.set_xticks([])
            ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()

Use the following function to visualize the data:

python
# Plot the data

def show_data(data_sample):
    plt.imshow(data_sample[0].numpy().reshape(28, 28), cmap='gray')
    plt.title('y = ' + str(data_sample[1]))

Make Some Data

Load the training dataset by setting the parameter train to True, and convert it to a tensor by placing a transform object in the argument transform.

python
# Create and print the training dataset

train_dataset = dsets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
print("Print the training dataset:\n ", train_dataset)

Load the testing dataset by setting the parameter train to False, and convert it to a tensor by placing a transform object in the argument transform.

python
# Create and print the validation dataset

validation_dataset = dsets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
print("Print the validation dataset:\n ", validation_dataset)

You can see that the label is an integer (indexing the dataset returns it as a Python int):

python
# Print the type of the element

print("Type of data element: ", type(train_dataset[0][1]))

Each element in the rectangular tensor corresponds to a number that represents a pixel intensity as demonstrated by the following image:

Figure: MNIST elements

In this image, the values are inverted, i.e., black represents white.
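
To make the pixel-intensity representation concrete, here is a minimal sketch that uses a small synthetic 8-bit image (the 2x3 array is made up for illustration and is not MNIST data). It mimics, in plain PyTorch, the scaling that `transforms.ToTensor()` applies: integers in [0, 255] become floats in [0, 1], and a leading channel dimension is added.

```python
import torch

# Hypothetical 2x3 grayscale "image" with 8-bit intensities (not real MNIST data)
raw = torch.tensor([[0, 128, 255],
                    [64, 192, 32]], dtype=torch.uint8)

# Mimic ToTensor(): convert to float, scale to [0, 1], add a channel dimension
scaled = (raw.float() / 255.0).unsqueeze(0)

print(scaled.shape)          # shape is (channel, height, width)
print(scaled.min().item())   # darkest pixel
print(scaled.max().item())   # brightest pixel
```

A real MNIST sample goes through the same kind of scaling, only with a 28x28 grid instead of 2x3.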

Print out the label of the fourth element:

python
# Print the label

print("The label: ", train_dataset[3][1])

The result shows that the number in the image is 1.

Plot the fourth sample:

python
# Plot the image

show_data(train_dataset[3])

You see that it is a 1. Now, plot the third sample:

python
# Plot the image

show_data(train_dataset[2])

Build a Softmax Classifier

Build a Softmax classifier class:

python
# Define softmax classifier class

class SoftMax(nn.Module):
    
    # Constructor
    def __init__(self, input_size, output_size):
        super(SoftMax, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        
    # Prediction
    def forward(self, x):
        z = self.linear(x)
        return z
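
Note that `forward` returns the raw linear outputs (logits) rather than probabilities. This works because the `nn.CrossEntropyLoss` criterion used later in this lab applies log-softmax internally, and because the predicted class does not change when softmax is applied. The following sketch illustrates this with made-up logits for a single sample and three classes; the values are assumptions chosen only for illustration.

```python
import torch

# Hypothetical logits for one sample over 3 classes (made-up values)
z = torch.tensor([[2.0, 1.0, 0.1]])

# Softmax turns logits into probabilities that sum to 1
probs = torch.softmax(z, dim=1)

# Softmax is monotonic, so the arg max of the logits and of the
# probabilities select the same class
pred_from_logits = torch.argmax(z, dim=1)
pred_from_probs = torch.argmax(probs, dim=1)

print(probs)
print(probs.sum(dim=1))
print(pred_from_logits.item(), pred_from_probs.item())
```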

The Softmax classifier requires vector inputs, but each image is a tensor of shape 1x28x28:

python
# Print the shape of train dataset

train_dataset[0][0].shape

Flatten the tensor as shown in this image:

Figure: Flatten image

The size of the tensor is now 784.

Figure: Flatten image
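
As a minimal sketch of the flattening step, the snippet below uses a random tensor as a stand-in for a batch of four MNIST images (the batch size and random values are assumptions for illustration). It uses the same `view(-1, 28 * 28)` call that appears in the training loop later in this lab.

```python
import torch

# Stand-in for a batch of 4 MNIST images: (batch, channel, height, width)
# Random values are used here instead of real data
batch = torch.rand(4, 1, 28, 28)

# view(-1, 28 * 28) keeps the batch dimension and flattens the rest
flat = batch.view(-1, 28 * 28)

print(batch.shape)
print(flat.shape)
```

Each row of `flat` holds the 784 pixels of one image, read row by row.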

Set the input size and output size:

python
# Set input size and output size

input_dim = 28 * 28
output_dim = 10

Define the Softmax Classifier, Criterion Function, Optimizer, and Train the Model

python
# Create the model

model = SoftMax(input_dim, output_dim)
print("Print the model:\n ", model)

View the size of the model parameters:

python
# Print the parameters

print('W: ',list(model.parameters())[0].size())
print('b: ',list(model.parameters())[1].size())
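
The sizes printed above follow from how `nn.Linear` stores its parameters. A short sketch, assuming the same input and output sizes as in this lab, shows the weight layout and the total parameter count:

```python
import torch.nn as nn

# A linear layer with the same sizes as in this lab
layer = nn.Linear(28 * 28, 10)

# PyTorch stores the weight as (out_features, in_features)
print(layer.weight.shape)
print(layer.bias.shape)

# Total trainable parameters: 10 * 784 weights + 10 biases
n_params = sum(p.numel() for p in layer.parameters())
print(n_params)
```

Because each of the 10 rows of the weight matrix has 784 entries, a row can be reshaped into a 28x28 image, one per class.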

You can reshape the model parameters for each class into a 28x28 grid and plot them as a square image:

python
# Plot the model parameters for each class

PlotParameters(model)

Define the learning rate, optimizer, criterion, data loader:

python
# Define the learning rate, optimizer, criterion and data loader

learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=5000)
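
To show what the criterion computes, here is a small sketch with made-up logits and labels (two samples, three classes; all values are assumptions for illustration). It compares `nn.CrossEntropyLoss` with a manual calculation: the negative log-softmax probability of the true class, averaged over the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits for 2 samples over 3 classes, with made-up labels
z = torch.tensor([[2.0, 0.5, -1.0],
                  [0.1, 0.2, 3.0]])
y = torch.tensor([0, 2])

# Built-in criterion (mean over the batch by default)
loss_builtin = nn.CrossEntropyLoss()(z, y)

# Manual version: negative log-probability of the true class, averaged
log_probs = F.log_softmax(z, dim=1)
loss_manual = -log_probs[torch.arange(len(y)), y].mean()

print(loss_builtin.item(), loss_manual.item())
```

The two values should agree up to floating-point error, which is why the model can output logits directly.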

Train the model and determine validation accuracy (should take a few minutes):

python
# Train the model

n_epochs = 10
loss_list = []
accuracy_list = []
N_test = len(validation_dataset)

def train_model(n_epochs):
    for epoch in range(n_epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            z = model(x.view(-1, 28 * 28))
            loss = criterion(z, y)
            loss.backward()
            optimizer.step()
            
        correct = 0
        # Perform a prediction on the validation data
        for x_test, y_test in validation_loader:
            z = model(x_test.view(-1, 28 * 28))
            _, yhat = torch.max(z.data, 1)
            correct += (yhat == y_test).sum().item()
        accuracy = correct / N_test
        # Store the loss of the last training batch as a plain Python float
        loss_list.append(loss.item())
        accuracy_list.append(accuracy)

train_model(n_epochs)
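
The validation step inside the loop can also be written as a standalone helper. The sketch below is one possible way to do it, not the lab's own code: `evaluate_accuracy` is a hypothetical name, and the toy model and random data stand in for the trained classifier and MNIST. It wraps the evaluation in `torch.no_grad()` so no gradients are tracked.

```python
import torch
import torch.nn as nn

def evaluate_accuracy(model, loader):
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():  # no gradients needed for evaluation
        for x, y in loader:
            z = model(x.view(x.size(0), -1))  # flatten each sample
            yhat = z.argmax(dim=1)            # class with the largest logit
            correct += (yhat == y).sum().item()
            total += y.size(0)
    return correct / total

# Usage with stand-in data (random tensors, not MNIST)
torch.manual_seed(0)
toy_model = nn.Linear(28 * 28, 10)
toy_x = torch.rand(20, 1, 28, 28)
toy_y = torch.randint(0, 10, (20,))
toy_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(toy_x, toy_y), batch_size=5)

acc = evaluate_accuracy(toy_model, toy_loader)
print(acc)
```

Counting `total` inside the loop avoids relying on a separately computed dataset size.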

Analyze Results

Plot the loss and accuracy on the validation data:

python
# Plot the loss and accuracy

fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.plot(loss_list,color=color)
ax1.set_xlabel('epoch',color=color)
ax1.set_ylabel('total loss',color=color)
ax1.tick_params(axis='y', color=color)
    
ax2 = ax1.twinx()  
color = 'tab:blue'
ax2.set_ylabel('accuracy', color=color)  
ax2.plot( accuracy_list, color=color)
ax2.tick_params(axis='y', color=color)
fig.tight_layout()

View the results of the parameters for each class after the training. You can see that they look like the corresponding numbers.

python
# Plot the parameters

PlotParameters(model)

Plot the first five misclassified samples along with the probability of the predicted class.

python
# Plot the misclassified samples
Softmax_fn=nn.Softmax(dim=-1)
count = 0
for x, y in validation_dataset:
    z = model(x.reshape(-1, 28 * 28))
    _, yhat = torch.max(z, 1)
    if yhat != y:
        show_data((x, y))
        plt.show()
        print("yhat:", yhat)
        print("probability of class ", torch.max(Softmax_fn(z)).item())
        count += 1
    if count >= 5:
        break

Plot the first five correctly classified samples along with the probability of the predicted class. You can see that the probability is much larger.

python
# Plot the correctly classified samples
Softmax_fn=nn.Softmax(dim=-1)
count = 0
for x, y in validation_dataset:
    z = model(x.reshape(-1, 28 * 28))
    _, yhat = torch.max(z, 1)
    if yhat == y:
        show_data((x, y))
        plt.show()
        print("yhat:", yhat)
        print("probability of class ", torch.max(Softmax_fn(z)).item())
        count += 1
    if count >= 5:
        break

About the Authors:

Joseph Santarcangelo has a PhD in Electrical Engineering; his research focused on using machine learning, signal processing, and computer vision to determine how videos impact human cognition. Joseph has been working for IBM since he completed his PhD.

Other contributors: Michelle Carey, Mavis Zhou


© IBM Corporation. All rights reserved.
