Building a Neural Network (CNN) for Handwritten Digit Recognition Using PyTorch

1. Introduction

Handwritten digit recognition is one of the classic problems in machine learning and serves as an excellent starting point for understanding neural networks. From postal code detection to document digitization, digit recognition has practical applications in a variety of domains.

In this tutorial, we’ll explore how to create a Convolutional Neural Network (CNN) capable of recognizing handwritten digits from the MNIST dataset, a benchmark collection of 28×28 grayscale images of the digits 0 to 9. We’ll use PyTorch, one of the most popular deep learning libraries, to build, train, and evaluate our model.

But we won’t stop there! Once the model achieves high accuracy, we’ll deploy it as a web application using Flask, making it accessible for real-world usage. This end-to-end guide will help you understand not just the fundamentals of CNNs, but also how to bridge the gap between research and deployment.

Whether you’re a beginner eager to dive into deep learning or an experienced developer looking to solidify your understanding of CNNs, this guide has something for everyone. Let’s get started!

2. Model Architecture

The network takes a single-channel 28×28 MNIST image as input and produces one score per digit class. Its architecture consists of:

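  • Two convolutional layers (conv1, conv2): These extract spatial features from the input images. Both use a 3×3 kernel with stride 1 and padding 1, which preserves the spatial size, and each is followed by a ReLU activation. conv1 maps the single input channel to 16 feature maps, and conv2 maps 16 to 32.
  • A max-pooling layer (pool): This 2×2 pooling with stride 2 is applied after each convolution and halves the height and width (28 → 14 → 7), reducing computation while keeping the strongest responses.
  • Fully connected layers (fc1, fc2): fc1 maps the flattened 32 × 7 × 7 = 1,568 features to 128 hidden units, and fc2 produces the final scores (logits) for the 10 digit classes.
  • Dropout (dropout): A dropout layer with p = 0.5, applied to the output of fc1, helps prevent overfitting by randomly zeroing activations during training.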
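The 32 * 7 * 7 input size of fc1 follows from the standard output-size formula for convolution and pooling layers, out = floor((n + 2p - k) / s) + 1. The short pure-Python sketch below traces the spatial size through the layers listed above. The helper name output_size is ours, not part of PyTorch, and it assumes a dilation of 1 and PyTorch's default floor rounding for pooling.

```python
def output_size(n, kernel_size, stride=1, padding=0):
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel_size) // stride + 1


size = 28
size = output_size(size, kernel_size=3, stride=1, padding=1)  # conv1 keeps 28
size = output_size(size, kernel_size=2, stride=2)             # pool halves to 14
size = output_size(size, kernel_size=3, stride=1, padding=1)  # conv2 keeps 14
size = output_size(size, kernel_size=2, stride=2)             # pool halves to 7

flattened = 32 * size * size  # conv2 produces 32 channels
print(size, flattened)  # 7 1568
```

Because padding=1 exactly offsets the 3×3 kernel, only the pooling layers change the spatial size. That is why two poolings take 28 down to 7.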

Here’s the code for the architecture:

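```python
import torch.nn as nn
import torch.nn.functional as F


class MNIST_CNN(nn.Module):
    def __init__(self):
        super(MNIST_CNN, self).__init__()
        # 1 input channel (grayscale) -> 16 feature maps; padding=1 keeps 28x28
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        # 16 -> 32 feature maps; padding=1 keeps the spatial size unchanged
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # 2x2 max pooling with stride 2 halves height and width
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # After two poolings (28 -> 14 -> 7): 32 channels of 7x7
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # one logit per digit class
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x: (batch, 1, 28, 28)
        x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 16, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # -> (batch, 32, 7, 7)
        x = x.view(-1, 32 * 7 * 7)            # flatten -> (batch, 1568)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)                   # only active in model.train() mode
        x = self.fc2(x)                       # raw logits, no softmax
        return x
```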
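To see where the model's capacity lives, each layer's parameter count can be worked out by hand. A Conv2d layer has out × in × k × k weights plus one bias per output channel, and a Linear layer has out × in weights plus one bias per output feature. Pooling and dropout have no parameters. The sketch below reproduces that arithmetic. The layer sizes are transcribed from MNIST_CNN above, and the helper names are ours.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    # weights + one bias per output channel
    return out_channels * in_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    # weights + one bias per output feature
    return out_features * in_features + out_features


layers = {
    "conv1": conv2d_params(1, 16, 3),
    "conv2": conv2d_params(16, 32, 3),
    "fc1": linear_params(32 * 7 * 7, 128),
    "fc2": linear_params(128, 10),
}
total = sum(layers.values())

for name, count in layers.items():
    print(f"{name}: {count:,} ({count / total:.1%})")
print(f"total: {total:,}")
```

fc1 holds about 97% of the roughly 207,000 parameters, and it is the layer whose output the dropout is applied to. With PyTorch installed, sum(p.numel() for p in MNIST_CNN().parameters()) should report the same total.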

3. Data Pre-processing

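The MNIST dataset is loaded with torchvision, which downloads it on first use. Each image is converted to a tensor, with ToTensor scaling pixel values from 0 to 255 down to [0, 1], and then normalized with a mean of 0.5 and a standard deviation of 0.5, which maps the values to [-1, 1]. Applying the same normalization at training and inference time keeps the input distribution consistent.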

```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def prepare_data(batch_size):
    # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # download=True fetches MNIST into ./data on the first run
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Shuffle training batches every epoch; keep the test order fixed
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader
```
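The two transforms are simple arithmetic. ToTensor divides 8-bit pixel values by 255, and Normalize computes (x - mean) / std for each channel. The pure-Python sketch below reproduces both steps for the values used above (mean = std = 0.5). The function names are ours, not torchvision's.

```python
def to_tensor_scale(pixel):
    # ToTensor: 8-bit value in [0, 255] -> float in [0.0, 1.0]
    return pixel / 255.0


def normalize(x, mean=0.5, std=0.5):
    # Normalize: (x - mean) / std, applied per channel
    return (x - mean) / std


for pixel in (0, 128, 255):
    print(pixel, round(normalize(to_tensor_scale(pixel)), 4))
```

Black background pixels become -1 and fully white strokes become 1. Another common choice is MNIST's own statistics, Normalize((0.1307,), (0.3081,)). Whichever values you pick must also be used at inference time.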

4. Model Training Process

Training a Convolutional Neural Network (CNN) to recognize handwritten digits involves several steps: defining the training loop, calculating loss, and optimizing the model weights. Below, we detail the training process using PyTorch.

Our training process is encapsulated in two key functions:

  1. train_model: Handles the main training loop over epochs, computes the loss, and updates the weights via backpropagation.
  2. save_model: Saves the trained model’s state to a file for later use in inference.

Here’s the training script:

```python
import torch


def train_model(model, trainloader, criterion, optimizer, device, num_epochs):
    model.train()  # enable dropout
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            # Move data to the specified device (GPU or CPU)
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()              # Zero the parameter gradients
            outputs = model(inputs)            # Forward pass
            loss = criterion(outputs, labels)  # Compute the loss
            loss.backward()                    # Backward pass: compute gradients
            optimizer.step()                   # Update the weights

            # Print the average loss over every 100 batches
            running_loss += loss.item()
            if i % 100 == 99:
                print(f"Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100:.4f}")
                running_loss = 0.0
    print("Finished Training!")


def save_model(model, path):
    # Save only the learned parameters (the state_dict), not the whole object
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")
```
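The listing defines train_model and save_model but never calls them, and the article does not say which loss, optimizer, or learning rate it uses. The sketch below fills that gap under explicit assumptions: nn.CrossEntropyLoss, Adam with lr=1e-2, and a linear stand-in model trained on synthetic data, so that it runs in seconds without downloading MNIST. The inner loop repeats the same per-batch steps that train_model performs.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic stand-in data (assumption): 512 random 10-feature samples labeled
# by the sign of their first feature, so there is a simple pattern to learn.
features = torch.randn(512, 10)
targets = (features[:, 0] > 0).long()
trainloader = DataLoader(TensorDataset(features, targets), batch_size=64, shuffle=True)

# Stand-in model; in the article's pipeline this would be MNIST_CNN().to(device).
model = nn.Linear(10, 2).to(device)

# Assumed choices: the article does not state its loss, optimizer, or learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


def full_loss():
    # Loss over the whole synthetic set, without tracking gradients.
    model.eval()
    with torch.no_grad():
        return criterion(model(features.to(device)), targets.to(device)).item()


initial_loss = full_loss()

# The same per-batch steps train_model performs, inlined so the sketch runs on its own.
model.train()
for epoch in range(5):
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

final_loss = full_loss()
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

With the article's own functions, the corresponding calls would be trainloader, testloader = prepare_data(64), model = MNIST_CNN().to(device), and train_model(model, trainloader, criterion, optimizer, device, num_epochs=5), followed by save_model. The batch size and epoch count are our choices, not the author's. PyTorch's default learning rate for Adam is 1e-3, which is a reasonable starting point for the CNN. Note that the model is moved to the device before the optimizer is created, so the optimizer holds references to the parameters that will actually be trained.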

The training script follows these steps:

  1. Move Data to Device: The inputs and labels are moved to the specified device (CPU or GPU).
  2. Zero Gradients: Any previously stored gradients are cleared before calculating new ones.
  3. Forward Pass: The input data is passed through the model to obtain predictions.
  4. Loss Calculation: The difference between predictions and true labels is computed using the loss function.
  5. Backpropagation: loss.backward() computes the gradient of the loss with respect to every weight, and optimizer.step() uses those gradients to update the weights.
  6. Monitor Loss: The average loss over each block of 100 batches is printed to track training progress.
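Steps 4 and 5 are easiest to see on a single example. The article does not name its criterion. Assuming nn.CrossEntropyLoss, the usual choice for multi-class logits, the loss for one sample is -log softmax(z)[y], and its gradient with respect to the logits z is softmax(z) minus the one-hot label. The pure-Python sketch below computes both and takes one gradient step directly on made-up logits. A real optimizer updates the weights that produce the logits, not the logits themselves.

```python
import math


def softmax(logits):
    # Subtract the max logit first for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    # Negative log-probability assigned to the correct class
    return -math.log(softmax(logits)[label])


def grad_wrt_logits(logits, label):
    # d(loss)/d(z_k) = softmax(z)_k - (1 if k == label else 0)
    probs = softmax(logits)
    return [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]


logits = [2.0, 0.5, -1.0]  # made-up scores for three classes
label = 1                  # the true class
learning_rate = 0.5

before = cross_entropy(logits, label)
grads = grad_wrt_logits(logits, label)
updated = [z - learning_rate * g for z, g in zip(logits, grads)]
after = cross_entropy(updated, label)

print(f"loss before: {before:.4f}, after one step: {after:.4f}")
```

The gradient for the true class is negative and the others are positive, so the step raises the correct logit, lowers the rest, and reduces the loss. In the real network, loss.backward() computes this same quantity and then propagates it back through fc2, fc1, and the convolutions.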

Saving the Trained Model

Once training is complete, the model’s weights are saved to a file using the save_model function. This enables us to load the pre-trained model for inference later:

```python
save_model(model, "mnist_cnn.pth")
```

Why Save the Model?

Saving the trained model means its weights can be reloaded for predictions without retraining, which is especially useful when deploying the model in a production environment or as a web service. Because only the state_dict is stored, the MNIST_CNN class definition must be available wherever the weights are loaded.
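Reloading mirrors saving: create a fresh instance of the same class, then copy the saved tensors into it with load_state_dict. The minimal sketch below uses a small nn.Linear as a stand-in for MNIST_CNN so that it runs without a trained checkpoint. The file lives in a temporary directory purely for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Small stand-in model; the same calls work for MNIST_CNN.
original = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(original.state_dict(), path)  # what save_model does

    # A state_dict holds only tensors, so a fresh instance must receive them.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # inference mode, as before any prediction

# Both models now produce identical outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(original(x), restored(x))
print(same)  # True
```

map_location="cpu" lets weights saved on a GPU machine load on a CPU-only server, which is what the Flask app in Section 6 relies on when it passes map_location=device.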

5. Evaluating the CNN Model

Once the CNN is trained, it’s essential to evaluate it on unseen test data to check how well it generalizes. In this section, we use a confusion matrix and a classification report to assess the model’s performance in detail.

Our evaluation script leverages PyTorch for inference and scikit-learn for computing metrics:

```python
import torch
from sklearn.metrics import classification_report, confusion_matrix


def evaluate_model(model, testloader, device):
    model.eval()  # disable dropout
    all_targets = []
    all_predictions = []

    with torch.no_grad():  # no gradients needed for inference
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Index of the highest logit in each row = predicted digit
            _, predicted = torch.max(outputs, 1)
            all_targets.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    print("Confusion Matrix:")
    print(confusion_matrix(all_targets, all_predictions))
    print("\nClassification Report:")
    print(classification_report(all_targets, all_predictions))
```

Evaluation Process

  1. Set Model to Evaluation Mode: model.eval() switches layers such as dropout (and batch normalization, if a model has it) to their inference behavior. For this model, it turns dropout off.
  2. Disable Gradient Computation: torch.no_grad() skips gradient tracking, which makes inference faster and uses less memory.
  3. Iterate Through Test Data:
    • Pass each batch of test images through the model.
    • Use torch.max(outputs, 1) to pick the index of the highest logit in each row as the predicted class. Because softmax preserves the ordering of the logits, this is also the class with the highest probability.
  4. Compute Metrics:
    • Confusion Matrix: Row i counts the test images whose true digit is i, and column j counts those predicted as j. Correct predictions lie on the diagonal, and each off-diagonal entry counts one specific mistake (for example, 4s predicted as 9s).
    • Classification Report: Lists precision, recall, F1-score, and support (the number of test samples) for each class, plus overall accuracy and macro and weighted averages.
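The following sketch shows what torch.max(outputs, 1) returns. It uses two made-up rows of logits over three classes and checks that the predicted indices match the argmax of the softmax probabilities.

```python
import torch

# A batch of two made-up logit vectors over three classes.
outputs = torch.tensor([[2.0, -1.0, 0.5],
                        [0.1, 0.3, 3.2]])

# torch.max along dim 1 returns (max values, indices of the max values).
values, predicted = torch.max(outputs, 1)

# Softmax is monotonic, so the most probable class has the same index.
probabilities = torch.softmax(outputs, dim=1)
assert torch.equal(predicted, probabilities.argmax(dim=1))

print(predicted.tolist())  # [0, 2]
```

The evaluation loop discards values with the underscore and keeps only the indices. outputs.argmax(dim=1) would give the same predictions.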

Sample Output

After running the evaluation script, you should see output along the following lines (rows are elided with ..., and exact numbers vary from run to run):

```
Confusion Matrix:
[[ 970    0    1    1    0    2    4    1    1    0]
 [   0 1126    2    1    0    1    3    0    2    0]
 ...
 [   0    1    0    0    0    0    0    0    2 1006]]

Classification Report:
              precision    recall  f1-score   support

           0       0.99      1.00      0.99       980
           1       1.00      1.00      1.00      1135
         ...
           9       0.99      0.99      0.99      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000
```

Interpreting the Results

  • Accuracy: The overall accuracy indicates the proportion of correctly classified digits.
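  • Precision: Of the images the model labels as a given digit, the fraction that truly show that digit. High precision means few false positives.
  • Recall: Of the images that truly show a given digit, the fraction the model identifies correctly. High recall means few false negatives.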
  • F1-Score: The harmonic mean of precision and recall, providing a balanced measure of performance.
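These quantities can be read directly off a confusion matrix laid out as scikit-learn does, with rows as true labels and columns as predictions. Precision = TP / (TP + FP) uses a column, recall = TP / (TP + FN) uses a row, and F1 = 2PR / (P + R). The pure-Python sketch below applies these formulas to a made-up six-sample example with three classes.

```python
def confusion(y_true, y_pred, num_classes):
    # matrix[i][j] counts samples whose true label is i and prediction is j
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_metrics(matrix):
    n = len(matrix)
    results = []
    for k in range(n):
        tp = matrix[k][k]
        predicted_k = sum(matrix[i][k] for i in range(n))  # column sum = TP + FP
        actual_k = sum(matrix[k])                            # row sum = TP + FN
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1))
    return results


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
matrix = confusion(y_true, y_pred, num_classes=3)
metrics = per_class_metrics(matrix)
accuracy = sum(matrix[k][k] for k in range(3)) / len(y_true)

print(matrix)  # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
for k, (p, r, f) in enumerate(metrics):
    print(f"class {k}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"accuracy={accuracy:.2f}")
```

Passing the same y_true and y_pred to scikit-learn's confusion_matrix and classification_report should reproduce these numbers. The "macro avg" row in the report is the unweighted mean of the per-class values, and "weighted avg" weights each class by its support.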

6. Deploying the Trained CNN as a Flask Web App

After training our Convolutional Neural Network (CNN) model, the next step is to deploy it so users can interact with it through a simple web interface. We’ll use Flask, a lightweight Python web framework, to serve our trained model.

Prerequisites

Before proceeding, ensure the following:

  • The trained model (mnist_cnn.pth) is saved in the project directory.
  • Flask is installed in your environment. If not, install it using:
    pip install Flask

The Flask App

Here is the Flask application that allows users to upload an image of a handwritten digit and get the predicted class:

```python
import torch
import torchvision.transforms as transforms
from flask import Flask, jsonify, render_template, request
from PIL import Image

from model.model import MNIST_CNN  # the class from Section 2, kept in model/model.py

# Initialize the Flask application; templates live one level up, in ../templates
app = Flask(__name__, template_folder='../templates')


@app.route('/')
def index():
    # No prediction yet on the first page load
    return render_template('index.html', predicted_class=None)


# Load the trained weights once, at startup (path relative to the launch directory)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MNIST_CNN()
model.load_state_dict(torch.load("mnist_cnn.pth", map_location=device))
model.to(device)
model.eval()  # disable dropout for inference

# Preprocessing: must match the normalization used during training
transform = transforms.Compose([
    transforms.Grayscale(),                # single channel, like MNIST
    transforms.Resize((28, 28)),           # MNIST image size
    transforms.ToTensor(),                 # [0, 255] -> [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,)),  # -> [-1.0, 1.0]
])


def predict(image):
    # Add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28)
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
    return predicted.item()


@app.route('/predict', methods=['POST'])
def predict_route():
    if 'image' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    try:
        image = Image.open(file.stream)
        predicted_class = predict(image)
        return render_template('index.html', predicted_class=predicted_class)
    except Exception as e:
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    app.run(debug=True)
```

Key Points to Note

  1. Preprocessing: The transform converts the uploaded image to grayscale, resizes it to 28×28, and turns it into a tensor normalized with the same mean and standard deviation (0.5, 0.5) used during training, matching the input format the model expects.
  2. Model Inference: torch.load() reads the saved state_dict, load_state_dict() copies it into a fresh MNIST_CNN, and model.eval() puts the model in inference mode.
  3. Web Interface: The index.html file in the templates directory (found via template_folder='../templates') serves as the front end for uploading images and displaying results.
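One practical caveat is that MNIST digits are light strokes on a dark background, while a photo or scan of handwriting is usually dark ink on light paper, and the transform above does not correct for that. The minimal Pillow sketch below inverts such images before resizing. It relies on a heuristic of ours, not part of the article's app: it assumes a mean brightness above 127 indicates a light background. The function name to_mnist_style is likewise hypothetical.

```python
from PIL import Image, ImageOps, ImageStat


def to_mnist_style(image: Image.Image) -> Image.Image:
    """Grayscale, invert if the background looks light, and resize to 28x28."""
    gray = image.convert("L")
    # Heuristic (assumption): a mostly light image is dark ink on white paper.
    if ImageStat.Stat(gray).mean[0] > 127:
        gray = ImageOps.invert(gray)
    return gray.resize((28, 28))
```

In the app, predict() could call transform(to_mnist_style(image)) instead of transform(image). The existing Grayscale and Resize steps then become no-ops, and the ToTensor and Normalize steps stay the same.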

The HTML Interface

Here’s a simple HTML template (index.html) that pairs with the Flask app:

```html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Digit Prediction</title>
</head>
<body>
    <h1>Upload an image of a digit</h1>
    <form action="/predict" method="post" enctype="multipart/form-data">
        <input type="file" name="image" accept="image/png, image/jpeg, image/webp">
        <input type="submit" value="Upload">
    </form>
    {# predicted_class is only set after a successful upload; 0 is a valid digit #}
    {% if predicted_class is defined and predicted_class is not none %}
        <h2>Predicted Digit: {{ predicted_class }}</h2>
    {% endif %}
</body>
</html>
```

Running the Flask App

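  1. Arrange the project so the paths used by the app resolve: save the Flask script as app/app.py, keep the MNIST_CNN class in model/model.py, put index.html in a templates folder at the project root (the app looks in '../templates' relative to app/), and keep mnist_cnn.pth in the project root.
  2. From the project root, start the Flask app by running:
     python -m app.app
  3. Open your browser and navigate to http://127.0.0.1:5000. You should see the web interface, where you can upload an image.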

Output Example

After uploading an image of a handwritten digit, the application will display the predicted digit on the same page.
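To exercise the upload endpoint without a browser, Flask's built-in test client can post a multipart form. The sketch below is self-contained, so it swaps in hypothetical pieces. predict() is a stub that always returns 7, and the route returns JSON instead of rendering index.html, so no template or checkpoint is needed. The request checks are the same as in the app above.

```python
import io

from flask import Flask, jsonify, request
from PIL import Image

app = Flask(__name__)


def predict(image):
    # Hypothetical stub standing in for the model-backed predict().
    return 7


@app.route("/predict", methods=["POST"])
def predict_route():
    if "image" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["image"]
    if file.filename == "":
        return jsonify({"error": "No selected file"}), 400
    image = Image.open(file.stream)
    # JSON instead of render_template, so the sketch needs no template.
    return jsonify({"predicted_class": predict(image)})


# Build a small in-memory PNG to upload.
buffer = io.BytesIO()
Image.new("L", (28, 28), color=0).save(buffer, format="PNG")
buffer.seek(0)

client = app.test_client()
missing = client.post("/predict")  # no file part at all
uploaded = client.post(
    "/predict",
    data={"image": (buffer, "digit.png")},
    content_type="multipart/form-data",
)
print(missing.status_code, uploaded.status_code, uploaded.get_json())
```

The same client.post call works against the real app, provided mnist_cnn.pth and the template are in place. In that case, the response body is the rendered HTML rather than JSON.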

Conclusion

In this article, we explored the journey of building a Convolutional Neural Network (CNN) to recognize handwritten digits using the MNIST dataset. From data preprocessing and model training to evaluation and deployment, each step showcased the power of PyTorch for deep learning and Flask for creating interactive web applications.

This project not only highlights the practical application of neural networks but also demonstrates how to bridge the gap between complex AI models and end-user accessibility. Deploying the model as a web app enables anyone to upload handwritten digit images and receive predictions in real time.

