4. Train the neural network

In this section, we will discuss how to train the previously defined network with data. We first import the libraries. The new ones are mxnet.init for more weight initialization methods, the datasets and transforms to load and transform computer vision datasets, matplotlib for drawing, and time for benchmarking.

In [1]:
# Uncomment the following line if matplotlib is not installed.
# !pip install matplotlib

from mxnet import nd, gluon, init, autograd
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
import matplotlib.pyplot as plt
from time import time

4.1. Get data

The handwritten digit dataset MNIST is one of the most commonly used datasets in deep learning, but it is so simple that reaching 99% accuracy takes little effort. Here we use a similar but slightly more complicated dataset called FashionMNIST. The goal is no longer to classify digits but clothing types.

The dataset can be automatically downloaded through Gluon’s data.vision.datasets module. The following code downloads the training dataset and shows the first example.

In [2]:
mnist_train = datasets.FashionMNIST(train=True)
X, y = mnist_train[0]
('X shape: ', X.shape, 'X dtype', X.dtype, 'y:', y)
('X shape: ', (28, 28, 1), 'X dtype', numpy.uint8, 'y:', 2)

Each example in this dataset is a \(28\times 28\) grayscale image, represented as an NDArray with the shape format (height, width, channel). The label is a numpy scalar.

Next, we visualize the first six examples.

In [3]:
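text_labels = [
    't-shirt', 'trouser', 'pullover', 'dress', 'coat',
    'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
]
X, y = mnist_train[0:6]
# plot images
_, figs = plt.subplots(1, X.shape[0], figsize=(15, 15))
for f,x,yi in zip(figs, X,y):
    # 3D->2D by removing the last channel dim
    f.imshow(x.reshape((28,28)).asnumpy())
    ax = f.axes
    # use the clothing type as the title and hide the axis ticks
    ax.set_title(text_labels[int(yi)])
    ax.title.set_fontsize(20)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()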

In order to feed data into a Gluon model, we need to convert the images to the (channel, height, width) format with a floating-point data type. This can be done with transforms.ToTensor. In addition, we normalize all pixel values using transforms.Normalize with mean 0.13 and standard deviation 0.31. We chain these two transforms together and apply them to the first element of each data pair, namely the image.

In [4]:
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])

mnist_train = mnist_train.transform_first(transformer)
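
To make the effect of the two transforms concrete, here is a minimal plain-Python sketch of the same arithmetic on nested lists. It is not Gluon's implementation: the helper names `to_tensor` and `normalize` and the tiny 2x2 image are made up for illustration, and it assumes that the conversion to floating point divides each pixel by 255.

```python
# Plain-Python sketch of ToTensor followed by Normalize (not Gluon's code).
def to_tensor(image_hwc):
    """Convert an (H, W, C) image of 0-255 integers to (C, H, W) floats."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Assumption: pixel values are scaled by dividing by 255.
    return [[[image_hwc[h][w][c] / 255.0 for w in range(width)]
             for h in range(height)]
            for c in range(channels)]


def normalize(image_chw, mean, std):
    """Shift and scale every pixel: (x - mean) / std."""
    return [[[(x - mean) / std for x in row] for row in channel]
            for channel in image_chw]


# A hypothetical 2x2 single-channel image in (height, width, channel) layout.
tiny = [[[0], [255]],
        [[51], [102]]]
chw = to_tensor(tiny)
out = normalize(chw, mean=0.13, std=0.31)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 1 2 2
```

The channel axis moves from last to first, and a pixel that was 0 ends up slightly negative after normalization because the mean is subtracted.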

FashionMNIST is a subclass of gluon.data.Dataset, which defines how to get the i-th example. In order to use it in training, we need to get a (randomized) batch of examples. This is easily done with gluon.data.DataLoader. Here we use four workers to process data in parallel, which is often necessary, especially for complex data transforms.

In [5]:
batch_size = 256

train_data = gluon.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)

The returned train_data is an iterator that yields batches of image and label pairs.

In [6]:
for data, label in train_data:
    print(data.shape, label.shape)
    # inspect only the first batch
    break
(256, 1, 28, 28) (256,)
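
Conceptually, a shuffling data loader permutes the example indices once per pass and then cuts them into consecutive groups. The following is a rough single-process sketch of that idea in plain Python; `iterate_batches` and the toy dataset are hypothetical names for illustration, and the real DataLoader additionally stacks examples into arrays and uses worker processes.

```python
import random


def iterate_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of examples with at most batch_size items each."""
    indices = list(range(len(dataset)))
    if shuffle:
        # A fresh permutation of the indices randomizes batch membership.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


# Hypothetical toy dataset of (feature, label) pairs.
toy = [(i * 0.1, i % 2) for i in range(10)]
batches = list(iterate_batches(toy, batch_size=4, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

Every example appears exactly once per pass, and the last batch is smaller when the dataset size is not a multiple of the batch size.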

Finally, we create a validation dataset and data loader.

In [7]:
mnist_valid = gluon.data.vision.FashionMNIST(train=False)
valid_data = gluon.data.DataLoader(
    mnist_valid.transform_first(transformer),
    batch_size=batch_size, num_workers=4)

4.2. Define the model

We reimplement the same LeNet introduced before. One difference here is that we changed the weight initialization method to Xavier, which is a popular choice for deep convolutional neural networks.

In [8]:
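net = nn.Sequential()
with net.name_scope():
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        # flatten the feature maps before the fully connected layers
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        # one output per clothing class
        nn.Dense(10)
    )
net.initialize(init=init.Xavier())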
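
It can help to trace how the spatial size of a \(28\times 28\) input shrinks through the convolution and pooling layers listed above. The sketch below does that with the usual output-size formula, assuming no padding and a stride of 1 for the convolutions; `conv_out` is a helper name introduced here, not part of Gluon.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                          # FashionMNIST images are 28x28
size = conv_out(size, 5)           # Conv2D(kernel_size=5)        -> 24
size = conv_out(size, 2, 2)        # MaxPool2D(2, strides=2)      -> 12
size = conv_out(size, 3)           # Conv2D(kernel_size=3)        -> 10
size = conv_out(size, 2, 2)        # MaxPool2D(2, strides=2)      -> 5
flat_features = 16 * size * size   # 16 channels after the second convolution
print(size, flat_features)         # 5 400
```

Under these assumptions the first dense layer receives 400 features per example.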

Besides the neural network, we need to define the loss function and optimization method for training. We will use the standard softmax cross entropy loss for classification problems. It first applies softmax to the output to obtain the predicted probabilities, and then compares them with the label using cross entropy.

In [9]:
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
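
The two steps of this loss can be written out for a single example in a few lines of plain Python. This is only an illustrative sketch, not the Gluon implementation; the function names and the example scores are made up here.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_cross_entropy_single(logits, label):
    """Negative log-probability that softmax assigns to the true class index."""
    return -math.log(softmax(logits)[label])


logits = [2.0, 1.0, 0.1]   # hypothetical scores for three classes
print(softmax_cross_entropy_single(logits, 0))  # small: class 0 has the top score
print(softmax_cross_entropy_single(logits, 2))  # larger: class 2 has the lowest score
```

The loss is small when the true class receives a high probability and grows as that probability approaches zero.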

The optimization method we picked is standard stochastic gradient descent with a constant learning rate of 0.1.

In [10]:
trainer = gluon.Trainer(net.collect_params(),
                        'sgd', {'learning_rate': 0.1})

The trainer is created with all parameters (both weights and gradients) in net. Later on, we only need to call the step method to update its weights.
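
For intuition, one plain SGD update can be sketched as follows. The sketch assumes that the gradients are summed over the batch and rescaled by 1/batch_size before the update, and that there is no momentum or weight decay; `sgd_step` and the numbers are hypothetical and only illustrate the arithmetic.

```python
def sgd_step(weights, grads, learning_rate, batch_size):
    """Return the weights after one plain SGD update on batch-summed gradients."""
    scale = learning_rate / batch_size
    return [w - scale * g for w, g in zip(weights, grads)]


weights = [0.5, -1.0]    # hypothetical parameter values
grads = [25.6, -51.2]    # hypothetical gradients summed over 256 examples
new_weights = sgd_step(weights, grads, learning_rate=0.1, batch_size=256)
print(new_weights)       # each weight moves against the sign of its gradient
```

Dividing by the batch size keeps the effective step comparable when the batch size changes.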

4.3. Train

We create an auxiliary function to calculate the model accuracy.

In [11]:
def acc(output, label):
    # output: (batch, num_output) float32 ndarray
    # label: (batch, ) int32 ndarray
    return (output.argmax(axis=1) ==
            label.astype('float32')).mean().asscalar()
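
The same accuracy computation can be spelled out in plain Python on nested lists: take the index of the largest score in each row and count how often it equals the label. The function name and the sample scores below are made up for illustration.

```python
def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    correct = 0
    for scores, label in zip(outputs, labels):
        # argmax over the class scores of one example
        predicted = max(range(len(scores)), key=lambda k: scores[k])
        correct += int(predicted == label)
    return correct / len(labels)


outputs = [[0.1, 2.0, 0.3],    # predicted class 1
           [1.5, 0.2, 0.1],    # predicted class 0
           [0.0, 0.4, 0.9]]    # predicted class 2
labels = [1, 0, 1]
print(accuracy(outputs, labels))  # two of the three predictions are correct
```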

Now we can implement the complete training loop.

In [12]:
for epoch in range(10):
    train_loss, train_acc, valid_acc = 0., 0., 0.
    tic = time()
    for data, label in train_data:
        # forward + backward
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # update parameters
        trainer.step(batch_size)
        # calculate training metrics
        train_loss += loss.mean().asscalar()
        train_acc += acc(output, label)

    # calculate validation accuracy
    for data, label in valid_data:
        valid_acc += acc(net(data), label)

    print("Epoch %d: Loss: %.3f, Train acc %.3f, Test acc %.3f, \
Time %.1f sec" % (
        epoch, train_loss/len(train_data), train_acc/len(train_data),
        valid_acc/len(valid_data), time()-tic))
Epoch 0: Loss: 0.828, Train acc 0.696, Test acc 0.810, Time 11.9 sec
Epoch 1: Loss: 0.472, Train acc 0.824, Test acc 0.854, Time 12.1 sec
Epoch 2: Loss: 0.402, Train acc 0.852, Test acc 0.859, Time 12.9 sec
Epoch 3: Loss: 0.360, Train acc 0.868, Test acc 0.875, Time 13.2 sec
Epoch 4: Loss: 0.336, Train acc 0.876, Test acc 0.882, Time 13.0 sec
Epoch 5: Loss: 0.320, Train acc 0.881, Test acc 0.884, Time 12.6 sec
Epoch 6: Loss: 0.303, Train acc 0.888, Test acc 0.879, Time 12.9 sec
Epoch 7: Loss: 0.294, Train acc 0.891, Test acc 0.888, Time 12.7 sec
Epoch 8: Loss: 0.280, Train acc 0.896, Test acc 0.893, Time 13.0 sec
Epoch 9: Loss: 0.269, Train acc 0.900, Test acc 0.899, Time 13.0 sec

4.4. Save the model

Finally, we save the trained parameters onto disk, so that we can use them later.

In [13]: