5. Predict with a pre-trained model
A saved model can be used in multiple ways: to continue training, to fine-tune the model, or to make predictions. In this tutorial we will discuss how to predict on new examples using a pre-trained model.
Please run the previous tutorial to train the network and save its parameters to a file. You will need this file to run the following steps.
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
import matplotlib.pyplot as plt
To start, we copy the definition of the LeNet-style network from the previous tutorial.
net = nn.Sequential()
with net.name_scope():
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(10)
    )
In the previous tutorial, we saved all parameters to a file; now let’s load them back.
Remember the data transformation we applied during training? We need exactly the same transformation for prediction.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])
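For reference, the two transforms amount roughly to the following arithmetic. This NumPy sketch is an illustration, not Gluon's implementation; it assumes each FashionMNIST image arrives as a 28 × 28 × 1 uint8 array and treats 0.13 and 0.31 as the mean and standard deviation arguments of Normalize. The half-black, half-white test image is made up.

```python
import numpy as np


def to_tensor(img):
    """Rough equivalent of ToTensor: HWC uint8 in [0, 255] -> CHW float32 in [0, 1]."""
    return img.transpose(2, 0, 1).astype(np.float32) / 255


def normalize(t, mean, std):
    """Rough equivalent of Normalize for one channel: (t - mean) / std."""
    return (t - mean) / std


# Hypothetical 28x28 grayscale image: left half black (0), right half white (255).
img = np.zeros((28, 28, 1), dtype=np.uint8)
img[:, 14:] = 255

out = normalize(to_tensor(img), 0.13, 0.31)
print(out.shape)  # (1, 28, 28)
print(out[0, 0, 0], out[0, 0, 27])  # black pixel, white pixel
```

The channel axis moves to the front and pixel values end up roughly centered around zero, which is the input distribution the network saw during training.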
Now let’s predict the first six images in the validation dataset and store the predictions in the list preds.
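mnist_valid = datasets.FashionMNIST(train=False)
X, y = mnist_valid[:6]
preds = []
for x in X:
    # Apply the training-time transform, then add a batch axis: (1, 1, 28, 28)
    x = transformer(x).expand_dims(axis=0)
    # The index of the highest score is the predicted class
    pred = net(x).argmax(axis=1)
    preds.append(pred.astype('int32').asscalar())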
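The loop feeds one image at a time, which is why expand_dims(axis=0) adds the batch axis the network expects. Stacking the transformed images into one (6, 1, 28, 28) array (in MXNet, for example with nd.stack) and calling the network once should give the same predictions. The NumPy sketch below checks this idea with a hypothetical random linear classifier standing in for net; it is not the tutorial's model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for net: a linear classifier over flattened pixels.
W = rng.normal(size=(28 * 28, 10))


def predict(batch):
    """batch: (N, 1, 28, 28) array -> predicted class index for each image."""
    logits = batch.reshape(len(batch), -1) @ W  # (N, 10) scores
    return logits.argmax(axis=1)


images = rng.normal(size=(6, 1, 28, 28))  # six already-transformed images

# One image at a time, as in the loop above: add a leading batch axis of size 1.
per_image = [int(predict(img[np.newaxis])[0]) for img in images]

# All six at once: a single call on the stacked (6, 1, 28, 28) array.
batched = predict(images).tolist()

print(per_image == batched)
```

Batching matters little for six images, but for larger validation sets a single call per batch is usually much faster than a Python loop.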
Finally, we visualize the images and compare the predictions with the ground truth. Each title shows the true label on the first line and the predicted label on the second.
_, figs = plt.subplots(1, 6, figsize=(15, 15))
text_labels = [
    't-shirt', 'trouser', 'pullover', 'dress', 'coat',
    'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
]
for f, x, yi, pyi in zip(figs, X, y, preds):
    f.imshow(x.reshape((28, 28)).asnumpy())
    ax = f.axes
    # Title: ground-truth label, then predicted label
    ax.set_title(text_labels[yi] + '\n' + text_labels[pyi])
    ax.title.set_fontsize(20)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
5.3. Predict with models from the Gluon model zoo
The LeNet trained on FashionMNIST is a good starting example, but it is too simple to recognize real-world photos. Instead of training a large-scale model from scratch, we can use the Gluon model zoo, which provides multiple powerful pre-trained models. For example, we can download and load a ResNet-50 V2 model pre-trained on the ImageNet dataset.
from mxnet.gluon.model_zoo import vision as models
from mxnet.gluon.utils import download
from mxnet import image

net = models.resnet50_v2(pretrained=True)
We also download and load the text labels for each class.
url = 'http://data.mxnet.io/models/imagenet/synset.txt'
fname = download(url)
with open(fname, 'r') as f:
    text_labels = [' '.join(l.split()[1:]) for l in f]
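The list comprehension assumes each line of synset.txt starts with a WordNet ID followed by the human-readable class names, and that the line order matches the class index of the model's outputs (which is how text_labels[i] is used later). A small sketch with made-up lines shows what the parsing keeps; the IDs below are placeholders, not real entries from the file.

```python
import io

# Illustrative lines in the "<WordNet ID> <names>" layout; the IDs are placeholders,
# not copied from the real synset.txt.
sample = io.StringIO(
    "n00000001 golden retriever\n"
    "n00000002 Irish setter, red setter\n"
)

# Same expression as above: drop the first token (the ID), keep the rest as the label.
text_labels = [' '.join(l.split()[1:]) for l in sample]
print(text_labels)  # ['golden retriever', 'Irish setter, red setter']
```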
We pick a dog image from Wikipedia as a test image, then download and read it.
url = ('https://upload.wikimedia.org/wikipedia/commons/thumb/b/b5/'
       'Golden_Retriever_medium-to-light-coat.jpg/'
       '365px-Golden_Retriever_medium-to-light-coat.jpg')
fname = download(url)
x = image.imread(fname)
We follow the conventional way of preprocessing ImageNet data:
1. resize the image so that its shorter edge is 256 pixels;
2. take a 224-by-224 crop from the center.
The following code uses the image processing functions provided in the MXNet image module.
x = image.resize_short(x, 256)
x, _ = image.center_crop(x, (224,224))
plt.imshow(x.asnumpy())
plt.show()
The image shows a golden retriever, as the image URL also suggests.
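The geometry of these two preprocessing steps can be sketched in NumPy. This is an illustration, not MXNet's implementation: it only computes the target size of the short-edge resize (the integer rounding is an assumption and may differ from MXNet's) and slices the central window, without any pixel interpolation. The 480 × 365 input size is hypothetical.

```python
import numpy as np


def resize_short_shape(h, w, size=256):
    """(height, width) after scaling so the shorter edge equals size, keeping the aspect ratio."""
    if h > w:
        return size * h // w, size
    return size, size * w // h


def center_crop(img, crop_h=224, crop_w=224):
    """Cut a crop_h x crop_w window from the middle of an HWC array."""
    h, w = img.shape[:2]
    y0 = (h - crop_h) // 2
    x0 = (w - crop_w) // 2
    return img[y0:y0 + crop_h, x0:x0 + crop_w]


new_h, new_w = resize_short_shape(480, 365)  # hypothetical input height and width
print(new_h, new_w)  # 336 256

resized = np.zeros((new_h, new_w, 3), dtype=np.uint8)  # placeholder pixels
print(center_crop(resized).shape)  # (224, 224, 3)
```

Resizing the short edge first preserves the aspect ratio, so the subsequent center crop removes only the borders of the longer side instead of distorting the object.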
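The remaining transformation is similar to the one used for FashionMNIST, except that each color channel is normalized separately: after scaling pixel values to [0, 1], we subtract the per-channel RGB mean and divide by the corresponding standard deviation. The image is also transposed from height-width-channel to channel-height-width layout, and a batch axis is added in front.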
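def transform(data):
    # HWC -> CHW, then add a batch axis: (1, 3, 224, 224)
    data = data.transpose((2,0,1)).expand_dims(axis=0)
    # Per-channel ImageNet mean and standard deviation, shaped to broadcast over (N, C, H, W)
    rgb_mean = nd.array([0.485, 0.456, 0.406]).reshape((1,3,1,1))
    rgb_std = nd.array([0.229, 0.224, 0.225]).reshape((1,3,1,1))
    # Scale to [0, 1], then normalize each channel
    return (data.astype('float32') / 255 - rgb_mean) / rgb_std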
Now we can recognize the object in the image. We apply a softmax to the network output to obtain probability scores, and then print the top-5 recognized objects.
prob = net(transform(x)).softmax()
# topk returns the indices with shape (1, 5); take the row for our single image
idx = prob.topk(k=5)[0]
for i in idx:
    i = int(i.asscalar())
    print('With prob = %.5f, it contains %s' % (
        prob[0,i].asscalar(), text_labels[i]))
With prob = 0.98642, it contains golden retriever
With prob = 0.00485, it contains Irish setter, red setter
With prob = 0.00319, it contains Labrador retriever
With prob = 0.00153, it contains English setter
With prob = 0.00133, it contains Brittany spaniel
As can be seen, the model is highly confident that the image contains a golden retriever, assigning it a probability of about 0.986.
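The softmax and top-k steps used above can be sketched in NumPy as an illustration (not MXNet's implementation): a numerically stable softmax, then top-k selection by sorting in descending order. The six logits are made-up numbers. Because softmax is monotonic, it does not change which classes rank highest; it is needed only to turn raw scores into probabilities that sum to 1.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def topk(scores, k=5):
    """Indices of the k largest scores, highest first."""
    return np.argsort(scores)[::-1][:k]


logits = np.array([2.0, 0.5, 4.0, 1.0, 3.0, -1.0])  # made-up scores for 6 classes
prob = softmax(logits)
for i in topk(prob, k=3):
    print('class %d: prob = %.3f' % (i, prob[i]))
```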