Posted on December 17, 2018 by Noon van der Silk | Back to recent posts

How much data do you need to retrain a classifier?

Tags: deep-dive

The code used for this blog-post is here.

A while ago, Gala had a great idea for a little animation: Give people an idea of how much data they need to train a machine learning model, by showing an animation of a model trained on, say, 5, 50, 500, 5,000, 50,000, 500,000 data points. The way I pictured it in my mind was that as you proceed to more and more data, the number of successful classifications gets larger; ultimately tapering off as we get to larger numbers. In my mind I saw it as the standard S-curve:

Quiz: Before reading the rest of this article, have a go at this quick quiz.

With the network and data-set we’re going to be focused on (the TensorFlow transfer learning example using InceptionV3), the expected validation accuracy is 90-95%, when run on the full data set (~700 images per class).

Note that this is using five kinds of flowers as the classes: tulip, rose, daisy, dandelion, sunflower.

Suppose we set aside 70 images per class for validation, in each of the following scenarios:

  • Q1: If we use 50% of the remaining images, what would you expect the validation accuracy to be?
  • Q2: If we only use 5% of the remaining images?
  • Q3: If we only use 3 images in total per class?
  • Q4: What about only 1 image per class!?

Background on Transfer Learning

For this article, we’re focusing on using a pre-trained network and specialising it to our specific (classification) task. This idea is called “Transfer Learnining”.

Here’s a picture that describes the idea:

There’s two steps:

First, you take some “pre-trained” network that is good at a particular task. That’s “1)” in the picture above; some network that is good a recognising trees, houses, and people. It’s been trained for this task and is really good at it.

Secondly, you cut off the final layer, and plug in your new classes that you want to predict. This gives you new parameters to learn (the edges coloured in magenta above). Then, you train your new network on your new data.

The point is, this network should have a lot of “juice” (i.e. learned recognition capability) in the earlier layers, that capture general-purpose structure; and the later layers, that you leave as learnable paremeters, allow the network to specialise to your task.

This is a really powerful idea in deep learning, and, in essence, is incredibly heavily used (from the idea of pre-trained word vectors, to pre-trained image classification networks, as we’re using here).

Let’s dive in

Here’s the plan:

  1. Follow the TensorFlow Transfer Learning Tutorial, but
  2. Create a “hold-out” dataset from the original dataset, say 10%,
  3. Create downsampled datasets, from the new dataset, of the following: 100%, 90%, 80%, 70%, 60%, 50%, 40%, 30%, 20%, 10%, 5%, as well as the 3-image and 1-image sets. Visually:
  1. Throw in some data augmentation, using imgaug,
  2. Train all these experiments,
  3. See how it all goes!

I’ve wrapped this up into a repo so you can reproduce my results: braneshop/how-much-data-experiments. Follow the link and check out the README if you want to run it yourself. It takes a few hours to run everything.

Results

Let’s take a look at the train/validation curves for all the training runs. You can click on the circle next to a given experiment name to show only that. Shift-click will show many at once. First, note that in the “just-1” and “just-3” image case, I hacked the normal training code to not worry about having any validation/test data allocation. As a result, we don’t have them in this graph.

Observations:

The real proof is in the accuracy against the holdout set, though, so let’s see how the models performed in this case. The following graph is the result of running all the models on the holdout images, and taking the maximum predicted value of all the resulting class predictions.

There were many surprising things about this graph for me. Notably, it doesn’t look like the graph that I predicted at the start (note that the axes are flipped), but probably the most shocking thing for me is this:

With just a single image per class, we get 68% accuracy on the holdout set! And with just 3 images, we get 75%! Recall that the holdout set consists of ~70 images per class.

Now, when I ran this initially, I hand-picked the 1-image and 3-images per class. I picked ones that looked pretty good, but I was still amazed at the results! I set up a notebook to run a few different examples of that, and I also evaluated the accuracy with the threshold approach, instead of the argmax; here are a few runs:

seed = 1                    seed = 2                    seed = 3

1 image                     1 image                     1 image
              max: 0.5                    max: 0.53                   max: 0.55
  threshold (0.5): 0.3        threshold (0.5): 0.38       threshold (0.5): 0.39
  threshold (0.8): 0.13       threshold (0.8): 0.18       threshold (0.8): 0.15

3 images                    3 images                    3 images
              max: 0.68                   max: 0.66                   max: 0.67
  threshold (0.5): 0.55       threshold (0.5): 0.51       threshold (0.5): 0.54
  threshold (0.8): 0.23       threshold (0.8): 0.25       threshold (0.8): 0.23

You can see that for randomly-picked images, the argmax sits aroud 53% for 1 image, and 67% for three images.

The point is, to me it’s quite amazing to see that picking the argmax, instead of using the threshold, gives such a radical improvement on withheld data.

I’ve made a notebook on Google Colaboratory, so if you have a Google account, you can run the experiments for yourself (you’ll have to save it as a new copy).

Discussion

For me, this was quite surprising. For one, I didn’t see the improvement at the middle-range of data that I was expecting (50% training data wasn’t much better than 100%, excluding the holdout set). But moreso, just how much juice can be squeezed out of this particular network + a handful of images is amazing.

The main thing this highlights to me is the power of semi-supervised learning, active learning, and One-shot learning.

Specifically, it suggests to me that the best thing to do is to set up a learning+data acquisition process that makes predictions as fast as possible, and constantly learns as it goes:

Overall, the idea of incorporating data into a model as-it-infers isn’t new, but to me, again, what is surprising is just how little data you need to get started; at least in this specific case!

To reiterate differently: If you’re a business wondering how to adopt AI, one plan would be to build this kind of human-in-the-loop system, where any prediction that is, say, better than 50% chance of being right will be useful to humans. Then, have the humans feed this system as they use it, and it will radically improve with just a small number of confirmed answers!

So, to answer the question from the title, we might say “less than you think!”, if you have the right setup to incorporate new data into your system, as you go.

Open questions

Quiz answers

Using the maximum value of the class estimates (argmax):

  • Q1: If we use 50% of the remaining images, what would you expect the validation accuracy to be? Answer: ~91%
  • Q2: If we only use 5% of the remaining images? Answer: ~83%
  • Q3: If we only use 3 images in total per class? Answer: ~67%
  • Q4: What about only 1 image per class!? Answer: ~53%

The animation

Here’s my attempt at the animation (click to view):

The images with the red text above them are the ones where the inferences were wrong. Unfortunately (or fortunately!) it’s not very compelling, because all the accuracies are basically very high!