Transfer Learning with VGG16 and Keras

How to use a state-of-the-art trained NN to solve your image classification problem

Gabriel Cassimiro
Towards Data Science


The main goal of this article is to demonstrate, with code and examples, how you can use an already trained CNN (convolutional neural network) to solve your specific problem.

Convolutional networks are great for image problems; however, they are computationally expensive if you use a big architecture and don’t have a GPU. For that, we have two solutions:

GPUs

GPUs are much more efficient at training NNs, but they are not that common on regular computers. That is where Google Colab comes to the rescue: it offers virtual machines with GPUs of up to 16 GB of memory, and the best part of it all is that it is free.

But even with those upgraded specs, you can still struggle when training a brand new CNN. That’s where Transfer Learning can help you achieve great results with less expensive computation.

Transfer Learning

So what is transfer learning?

To explain it better, we must first understand the basic architecture of a CNN.

Image by Author

A CNN can be divided into two main parts: Feature learning and classification.

Feature Learning

In this part, the main goal of the NN is to find patterns in the pixels of the images that are useful for identifying the targets of the classification. This happens in the convolutional layers of the network, which specialize in the patterns relevant to the problem at hand.

I’m not going deep into how this works under the hood, but if you want to dig deeper, I highly recommend this article and this amazing video.

Classification

Now we want to use those patterns to assign each image its correct label. This part of the network does exactly that: it takes the patterns detected by the previous layers and finds the class that best matches them in a new image.
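To make the two-part split concrete, here is a minimal Keras sketch of a small CNN built as two named sub-models. The layer sizes, the 64x64 input, and the function name `build_two_part_cnn` are illustrative assumptions; this is not the VGG16 architecture used later.

```python
from tensorflow import keras


def build_two_part_cnn(input_shape=(64, 64, 3), num_classes=5):
    # Feature learning: convolution + pooling layers that learn pixel patterns
    feature_learning = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            keras.layers.Conv2D(16, 3, activation="relu"),
            keras.layers.MaxPooling2D(),
            keras.layers.Conv2D(32, 3, activation="relu"),
            keras.layers.MaxPooling2D(),
        ],
        name="feature_learning",
    )
    # Classification: dense layers that map the learned patterns to class probabilities
    classification = keras.Sequential(
        [
            keras.layers.Flatten(),
            keras.layers.Dense(32, activation="relu"),
            keras.layers.Dense(num_classes, activation="softmax"),
        ],
        name="classification",
    )
    # Chain the two parts into one end-to-end model
    inputs = keras.Input(shape=input_shape)
    outputs = classification(feature_learning(inputs))
    return keras.Model(inputs, outputs), feature_learning, classification
```

Transfer learning keeps something like `feature_learning` (with pretrained weights) and replaces only `classification`.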

Definition

So now we can define Transfer Learning, in our context, as reusing the feature learning layers of a trained CNN to solve a classification problem different from the one it was created for.

In other words, we take the patterns that the NN learned while classifying images for one problem and use them to classify a completely different problem, without retraining that part of the network.

Now I am going to demonstrate how you can do that with Keras, and show that in many cases this gives better results than training a new network from scratch.

Transfer Learning With Keras

For this demonstration, I will use a famous CNN called VGG16. This is its architecture:

Image by Author

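This network was trained on the ImageNet dataset, specifically the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) subset of about 1.2 million images belonging to 1,000 different labels (the full ImageNet database contains more than 14 million high-resolution images).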

If you want to dig deeper into this specific model you can study this paper.

Dataset

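For this demonstration, I will use the tf_flowers dataset, which contains 3,670 photos of five kinds of flowers (daisies, dandelions, roses, sunflowers, and tulips). Just as a reminder: the VGG16 network was not trained to classify different kinds of flowers.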

This is what the data looks like:

Image by Author

Finally…

The Code

First, we have to load the dataset from TensorFlow Datasets:
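A minimal sketch of this step, assuming the `tensorflow_datasets` package is installed. Two choices here are my assumptions rather than details from the text: a 150x150 input size and a 70%/30% train/test split. Because tf_flowers only provides a `train` split, the sketch slices it into two non-overlapping parts so that no test image is seen during training. The helper names `preprocess` and `load_flowers` are hypothetical.

```python
import tensorflow as tf

IMG_SIZE = (150, 150)  # assumed input size; not stated in the text
NUM_CLASSES = 5        # tf_flowers has five flower classes


def preprocess(image, label, img_size=IMG_SIZE, num_classes=NUM_CLASSES):
    """Resize one image and one-hot encode its integer label."""
    image = tf.image.resize(image, img_size)  # returns float32
    label = tf.one_hot(label, depth=num_classes)
    return image, label


def load_flowers(batch_size=32):
    # Imported here so `preprocess` can be used without tensorflow_datasets.
    import tensorflow_datasets as tfds

    # Non-overlapping slices of the single "train" split: first 70% / last 30%.
    train_ds, test_ds = tfds.load(
        "tf_flowers",
        split=["train[:70%]", "train[70%:]"],
        as_supervised=True,  # yield (image, label) pairs
    )
    # Images have varying sizes, so resize each one before batching.
    train_ds = train_ds.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_ds = test_ds.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return train_ds, test_ds
```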

Now we can load the VGG16 model.

We use include_top=False to remove the classification layers that were trained on the ImageNet dataset, and we set the model as not trainable. We also use the preprocess_input function from VGG16 to normalize the input data.
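Sketched in code, assuming TensorFlow's bundled Keras and a 150x150 input size. `build_frozen_base` and `vgg_normalize` are hypothetical helper names; `VGG16`, `include_top`, and `preprocess_input` are the real Keras API. Note that `preprocess_input` does not rescale pixels to [0, 1]: for VGG16 it converts RGB to BGR and subtracts the ImageNet mean of each channel.

```python
import numpy as np
from tensorflow import keras

vgg16 = keras.applications.vgg16


def build_frozen_base(input_shape=(150, 150, 3), weights="imagenet"):
    """Load VGG16's convolutional base and freeze it."""
    base_model = vgg16.VGG16(
        weights=weights,          # ImageNet weights: the learned feature detectors
        include_top=False,        # drop the ImageNet-specific dense classifier
        input_shape=input_shape,  # assumed size; VGG16 needs at least 32x32
    )
    base_model.trainable = False  # keep the pretrained convolutions fixed
    return base_model


def vgg_normalize(images):
    """Apply VGG16's preprocessing (RGB -> BGR, subtract ImageNet channel means)."""
    # np.array copies, so the caller's array is not modified in place.
    return vgg16.preprocess_input(np.array(images, dtype="float32"))
```

Passing `weights=None` builds the same architecture with random weights, which is handy for quick offline checks but defeats the purpose of transfer learning.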

We can run this code to check the model summary.

```python
base_model.summary()
```
Image by Author

Two main points: the model has over 14 million pretrained parameters (14,714,688), none of which will be updated during training, and it ends with a max-pooling layer that belongs to the Feature Learning part of the network.
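Both facts can be checked by hand. The short, dependency-free sketch below assumes the standard VGG16 configuration (13 convolutional layers with 3x3 kernels and "same" padding, each of the five blocks ending in a 2x2 max-pool) and an illustrative 150x150 input. Each convolution contributes `3 * 3 * in_channels * filters` weights plus one bias per filter; the function names are hypothetical.

```python
# VGG16 convolutional base: (number of 3x3 conv layers, output channels) per block
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def conv_base_params(blocks=VGG16_BLOCKS, in_channels=3, kernel=3):
    total = 0
    channels = in_channels
    for n_convs, filters in blocks:
        for _ in range(n_convs):
            total += kernel * kernel * channels * filters + filters  # weights + biases
            channels = filters
    return total


def conv_base_output_shape(height, width, blocks=VGG16_BLOCKS):
    # 'same'-padded convs keep the spatial size; each block's 2x2 max-pool halves it (floor)
    for _ in blocks:
        height, width = height // 2, width // 2
    return height, width, blocks[-1][1]


print(conv_base_params())                # 14714688
print(conv_base_output_shape(150, 150))  # (4, 4, 512)
```

With a 150x150 input, the base therefore hands a 4x4x512 feature map (8,192 values once flattened) to whatever layers we add on top.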

Now we add the last layers for our specific problem.
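One possible head, as a sketch: the text does not list the exact layers, so the sizes below (Dense layers of 50 and 20 units followed by a 5-way softmax, one output per tf_flowers class) are assumptions. It uses the functional API and calls the base with `training=False`, the pattern the Keras transfer-learning guide recommends for frozen bases. `build_transfer_model` is a hypothetical name.

```python
from tensorflow import keras


def build_transfer_model(base_model, input_shape=(150, 150, 3), num_classes=5):
    inputs = keras.Input(shape=input_shape)
    # training=False keeps the frozen base in inference mode
    x = base_model(inputs, training=False)
    # New classification part: only these layers will be trained
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(50, activation="relu")(x)
    x = keras.layers.Dense(20, activation="relu")(x)
    outputs = keras.layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
```

Only the three Dense layers are trainable: with a 150x150 input that is 410,775 parameters, compared with more than 14 million in the frozen base.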

Then we compile and fit the model.
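A sketch of the compile-and-fit step. The optimizer (Adam), the loss (categorical cross-entropy, which matches one-hot labels), and the early-stopping settings are my assumptions, not values given in the text. `compile_and_fit` is a hypothetical helper, and `val_data` should be a held-out slice kept separate from the test set so that early stopping never looks at test images.

```python
from tensorflow import keras


def compile_and_fit(model, train_data, val_data, epochs=50):
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",  # labels are one-hot encoded
        metrics=["accuracy"],
    )
    # Stop when validation accuracy stalls for 5 epochs and keep the best weights
    early_stop = keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=5, restore_best_weights=True
    )
    return model.fit(
        train_data, validation_data=val_data, epochs=epochs, callbacks=[early_stop]
    )
```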

Evaluating this model on the test set, we got 96% accuracy!
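For reference, the accuracy reported by `model.evaluate` (when the model is compiled with `metrics=["accuracy"]` and a categorical loss) is top-1 accuracy: the fraction of images whose highest-probability class equals the true class. This small NumPy sketch, with the hypothetical name `top1_accuracy`, computes the same quantity from the output of `model.predict`.

```python
import numpy as np


def top1_accuracy(probs, one_hot_labels):
    """Fraction of samples whose highest-scoring class matches the true class."""
    predicted = np.argmax(probs, axis=1)
    actual = np.argmax(one_hot_labels, axis=1)
    return float(np.mean(predicted == actual))


probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1], [0.2, 0.2, 0.6]])
labels = np.eye(3)[[0, 2, 1, 2]]  # true classes 0, 2, 1, 2
print(top1_accuracy(probs, labels))  # 0.75 (3 of 4 correct)
```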

That’s it!

It is this simple. And it is kind of beautiful, right?

We can find patterns in the world that can be used to identify completely different things.

If you want to check out the complete code and a Jupyter notebook, here’s the GitHub repo:

Extra: comparing to hand-made model

To check whether this approach is better in terms of both computational resources and accuracy, I created a simple hand-made model for this problem.

This is the code:

I used the same final layers and fit parameters so that the comparison isolates the impact of the convolutional layers.
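A minimal sketch of such a from-scratch baseline, under my own assumptions: three small convolution/max-pooling stages trained from random initialization (the filter counts are assumed), followed by a dense head of 50 and 20 units and a 5-way softmax (also assumed). It rescales raw pixels to [0, 1] with a `Rescaling` layer instead of using VGG16's `preprocess_input`. `build_handmade_model` is a hypothetical name.

```python
from tensorflow import keras


def build_handmade_model(input_shape=(150, 150, 3), num_classes=5):
    return keras.Sequential(
        [
            keras.Input(shape=input_shape),
            keras.layers.Rescaling(1.0 / 255),  # scale raw pixels to [0, 1]
            # Feature learning from scratch: every weight starts random
            keras.layers.Conv2D(16, 3, activation="relu"),
            keras.layers.MaxPooling2D(),
            keras.layers.Conv2D(32, 3, activation="relu"),
            keras.layers.MaxPooling2D(),
            keras.layers.Conv2D(64, 3, activation="relu"),
            keras.layers.MaxPooling2D(),
            # Dense head kept identical to the transfer model's head
            keras.layers.Flatten(),
            keras.layers.Dense(50, activation="relu"),
            keras.layers.Dense(20, activation="relu"),
            keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
```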

The accuracy of the hand-made model was 83%, much worse than the 96% we got from the VGG16 model.
