Is it Really an Egg Roll?
Yesterday, I taught a computer vision workshop. In the section about image classification, I had the attendees train a deep learning model that could correctly classify images of the letters A, B, C, E, and H. To build the model, they started with a previously trained model. This is known as "transfer learning." More specifically, transfer learning is a method where a model developed for a task is reused as the starting point for a model on a second task. It is an optimization that allows rapid progress or improved performance when modeling the second task.
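The workshop's exact code isn't shown here, but in PyTorch the idea looks roughly like the sketch below: reuse a pre-trained backbone and train a new classification head for the five letters. The choice of torchvision's ResNet-18 is purely an assumption for illustration.

```python
# A minimal transfer-learning sketch (illustrative, not the workshop's code).
# Assumes torchvision's ResNet-18 as the pre-trained starting point.
import torch.nn as nn
from torchvision import models

# Start from weights learned on a different task.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Optionally freeze the pre-trained layers so only the new head is trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer with a new head for the 5 letter classes: A, B, C, E, H.
model.fc = nn.Linear(model.fc.in_features, 5)

# ...train model.fc on the letter images as usual...
```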
One of the models that could be used as a starting point had been trained on images of Chinese food. Because they were starting from that model, one of the attendees asked whether the newly trained letter classifier could also correctly classify images of Chinese food, even though no pictures of Chinese food were included in the training data set. It makes sense, right? I explained that the pre-trained Chinese food model they started with was simply the weights and biases produced during that initial training; it did not include any of the classes, i.e., the names of each type of food. By starting with those weights and biases, training the letter classification model could be sped up and optimized.
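To make that concrete: if you open up a saved pre-trained model, all you find is named tensors of numbers. A quick, hypothetical check (again using torchvision's ResNet-18 as a stand-in) might look like this:

```python
# Inspecting a pre-trained model: it contains only weight and bias tensors,
# no class names (ResNet-18 is used here purely as a stand-in).
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
state = model.state_dict()

for name, tensor in list(state.items())[:3]:
    print(name, tuple(tensor.shape))
# conv1.weight (64, 3, 7, 7)
# bn1.weight (64,)
# bn1.bias (64,)
# Nothing in the checkpoint says "egg roll", "dumpling", or any other class name.
```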
To test what I had said, the student did the following (a rough code sketch follows the list):
1. Downloaded 5 pictures of egg rolls and labeled them as egg rolls. There were already 100 pictures of each of the letters in the data set.
2. Retrained his model to now include 6 classes: A, B, C, E, H, and egg roll.
3. Performed inference using a picture of an egg roll.
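The post doesn't include the student's actual code, but the retraining and inference steps might look roughly like this, again assuming a torchvision-style setup with images organized into one folder per class (all paths and file names below are placeholders):

```python
# A rough sketch of steps 2 and 3 above (illustrative; paths are placeholders).
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from PIL import Image

preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Step 2: the training folder now has six subfolders: A, B, C, E, H, egg_roll
# (only 5 egg roll images versus 100 images per letter).
train_data = datasets.ImageFolder("letters_plus_eggroll/", transform=preprocess)

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(train_data.classes))
# ...retrain on train_data as before...

# Step 3: inference on a single egg roll picture.
model.eval()
img = preprocess(Image.open("egg_roll_test.jpg").convert("RGB")).unsqueeze(0)
probs = torch.softmax(model(img), dim=1)
confidence, idx = probs.max(dim=1)
print(train_data.classes[idx.item()], confidence.item())
```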
The student found that the egg roll image was indeed classified as an egg roll with high confidence. Since there were so few egg roll images compared with the letters, didn't that mean some knowledge of the shape of an egg roll was being carried over from the pre-trained model?
I had the student conduct the following experiment to test his hypothesis. He found a picture of a hot dog, which looks similar to an egg roll, and performed inference on it with his newly trained model. It was classified as an egg roll with high confidence. Classifying it as an egg roll made some sense, since the two have a similar shape, but the confidence of the prediction was far higher than it should have been. Was the hot dog classified as an egg roll simply because the egg roll was the only thing in his data set that looked anything like a hot dog? He then tried an image of a dog, something that does not look like an egg roll at all. To his surprise, the model classified the dog as an egg roll with high confidence. Looking at the heat map showing which part of the image the model used to make its decision, he discovered that the dog's tail was driving the classification. Clearly the dog should not have been classified as an egg roll; ideally, it should not have been classified as anything. He then tried another picture of a dog, one without a long tail, and it was not classified as anything. The conclusion was that the hot dog and the long-tailed dog were being classified as egg rolls because they looked more like an egg roll than like a letter, not because the model was using prior knowledge about the characteristics of egg rolls.
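The post doesn't say which visualization tool produced that heat map, but a Grad-CAM-style map can be sketched in a few lines of PyTorch. Everything here (the model, the target layer, and the image path) is assumed for illustration:

```python
# A minimal Grad-CAM-style heat map sketch (assumptions: torchvision ResNet-18,
# its last conv block as the target layer, and "dog.jpg" as a placeholder image).
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()

# Capture the activations of the last convolutional block during the forward pass.
feats = {}
model.layer4.register_forward_hook(lambda module, inp, out: feats.update(act=out))

preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
img = preprocess(Image.open("dog.jpg").convert("RGB")).unsqueeze(0)

logits = model(img)
score = logits[0, logits.argmax()]  # score of the predicted class

# Gradient of that score with respect to the captured activations.
grads = torch.autograd.grad(score, feats["act"])[0]

# Weight each activation map by its average gradient, sum, and keep positive parts.
weights = grads.mean(dim=(2, 3), keepdim=True)
cam = F.relu((weights * feats["act"]).sum(dim=1, keepdim=True))
cam = F.interpolate(cam, size=img.shape[2:], mode="bilinear", align_corners=False)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
# `cam` is now a heat map the same size as the input; bright regions (for
# example, a long tail) are what pushed the model toward its predicted class.
```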
This simple experiment underscores a very important principle in data science. A model can predict something with very high confidence yet be completely wrong. High confidence during inference does not always mean that the model is good. Conversely, low confidence does not always mean the model is bad; we may just need more training data or different hyperparameters. We must always go back and do sanity checks on our results instead of taking it for granted that our model is good or bad.
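One reason this happens is that a softmax classifier has to divide all of its probability mass among the classes it knows, so there is no built-in "none of the above." A tiny, made-up illustration (the logits below are invented, not the student's actual outputs):

```python
# Softmax spreads probability only over known classes, so an unrelated image
# can still receive a "confident" prediction. The logits here are invented.
import torch
import torch.nn.functional as F

classes = ["A", "B", "C", "E", "H", "egg roll"]
logits = torch.tensor([[-1.2, -0.8, -1.5, -0.9, -1.1, 2.3]])  # hypothetical dog image

probs = F.softmax(logits, dim=1)
print(probs.sum().item())                 # 1.0; there is no "none of the above" option
print(classes[probs.argmax().item()],     # egg roll
      round(probs.max().item(), 2))       # ~0.85, despite the input being a dog
```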
Be sure to follow me on Twitter (@pacejohn) or visit my blog here!
Thanks to Jason Brownlee for the definitions of transfer learning.