The misguided intuition I had to unlearn to come to grips with modern machine learning
By Charles Fisher, Founder and CEO, Unlearn.AI
When my postdoc advisor introduced me to the field of deep learning in 2012, I was extremely skeptical. Prior to that fateful moment, I had it burned into my brain that smaller models were preferable to larger ones. That larger models were prone to overfitting. And that interpretable models were more reliable than black box predictive models. It turns out that everything I thought I knew about machine learning was wrong. The intuition about modeling that I had taken from my statistics textbooks was misguided.
This blog post explains three key concepts that I misunderstood. Well, not just me; in some sense, the whole field misunderstood these concepts. In fact, I think that many researchers in machine learning-adjacent fields continue to misunderstand these concepts, and this holds them back from adopting new (and better) methods. Hence this post.
Although it’s an oversimplification, I think it’s useful to think of the field of predictive modeling as going through three different eras.
First, there was the era of small specialized models that lasted from the time of Gauss in the early 1800s all the way until the 2000s. That’s a long time, which is partly why the intuition developed from this era is so widespread and difficult to move on from. This is the era of linear and logistic regression, decision trees, and support vector machines: models with a relatively small number of parameters that are trained to predict a specific outcome and have some built-in interpretability.
Next, starting in the 2010s, came the era of big specialized models. This was the era in which deep neural networks came to the forefront of machine learning research. The shift was precipitated partly by the development of new modeling techniques, but perhaps more importantly by the arrival of sufficiently powerful computers and software frameworks. These models were much larger than models from the prior era, often having millions of parameters, but they nonetheless performed really well! But, in order to get these large, specialized models to perform well, one had to collect a huge amount of data on the desired predictive task.
Now, we are living in the era of big general models. This is the era of generative AI and transformers; gigantic models trained to solve many different tasks at the same time. Remarkably, these models can learn to solve tasks they’ve never seen any data on at all! In a sense, the current zeitgeist is that scale is all you need. Training really big models on really big general datasets can sometimes unlock something almost like magic.
The move from small specialized models to big specialized models in the 2010s came with a significant jump in capabilities. Likewise, the move from big specialized models to big general models that we are currently living through is providing another significant jump in capabilities. Despite this, many people in machine learning-adjacent fields continue to express skepticism about these new types of models. So, here are three concepts I had to unlearn before I could hop aboard the machine learning train.
Old: The parameters in a model should be interpretable.
New: Forget trying to interpret the model parameters and just learn to predict.
Old: Models with lots of parameters are likely to overfit.
New: Adding more parameters to a model often makes it generalize better.
Old: It’s important to train a model on data relevant to its intended use.
New: It’s better to train a model to do many different tasks than any one specific task.
Let me go through each of these in a bit more detail.
Old: The parameters in a model should be interpretable.
New: Forget trying to interpret the model parameters and just learn to predict.
In traditional statistical modeling, it’s common to think of each parameter in a model as having a clear interpretation. For example, a coefficient in a linear regression converts a change in the independent variable (x) into a change in the dependent variable (y): if we increase x by one unit, how much do we expect y to change?
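For concreteness, here’s what that kind of interpretation looks like in code. This is a minimal sketch using scikit-learn on made-up data where the true slope is 2; the numbers are purely illustrative.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy data (made up): y is roughly 2*x plus noise.
rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 1))
y = 2.0 * x[:, 0] + rng.normal(scale=0.5, size=1000)

model = LinearRegression().fit(x, y)
# The single coefficient has a direct reading: increasing x by one unit
# shifts the predicted y by roughly this much.
print(model.coef_[0], model.intercept_)  # approximately 2.0 and 0.0
```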
Modern machine learning models, however, have a large number of parameters and many of these parameters are actually helping the model to train, not helping it to predict, which makes interpreting individual parameters next to impossible. What do I mean by this?
To train a neural network, we use an algorithm called gradient descent that takes small steps in parameter space in the direction that minimizes the model's prediction error. Here’s a metaphor. Imagine that you’re walking across a landscape trying to get to its lowest point, but you’re only allowed to move north, south, east, or west. What happens if you encounter an obstacle like a cliff? You’ll have to walk all the way around the obstacle, if that’s even possible at all. But if you could also move vertically (say, you have a rope), then it’s easy: you just rappel down the cliff and you’re on your way. The ability to move in a new direction lets you go around obstacles more easily. Adding more parameters to a neural network does the same thing: it opens up new directions for the gradient descent algorithm to travel, so that it can more easily go around obstacles and find a minimum of the prediction error.
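Stripped of the metaphor, gradient descent itself is simple. Here is a minimal, dependency-free sketch on a toy loss surface; the loss function, step size, and parameter count are invented for illustration, and the finite-difference gradient just keeps the sketch self-contained (real frameworks compute gradients by backpropagation).

```python
import numpy as np

def loss(theta):
    # A bumpy toy loss surface standing in for a model's prediction error.
    return np.sum(theta ** 2) + np.sum(np.sin(3.0 * theta))

def numerical_grad(theta, eps=1e-5):
    # Finite-difference gradient, used here only to avoid extra dependencies.
    g = np.zeros_like(theta)
    for i in range(theta.size):
        bump = np.zeros_like(theta)
        bump[i] = eps
        g[i] = (loss(theta + bump) - loss(theta - bump)) / (2 * eps)
    return g

theta = np.full(10, 2.5)                   # 10 parameters = 10 directions to move in
for _ in range(500):
    theta -= 0.05 * numerical_grad(theta)  # take a small step downhill
print(loss(theta))
```

Giving theta more entries gives each update more directions to move in, which is the same role the extra parameters play in the neural network case.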
Here’s an interesting experimental finding that illustrates the phenomenon. We randomly initialize the parameters of a big neural network and train it to a small prediction error. It turns out that we can often prune many of the parameters of the network (i.e., set them to zero) at this point without really hurting its performance. However, if we take the pruned network, randomly re-initialize the non-zero parameters, and then try to re-train the model, it fails to learn as well as it did the first time! Even though it could, in principle, find the exact same solution, it can’t reach it anymore. Many of the model’s parameters are needed only during training: they help the model learn, not make predictions.
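Here’s a rough sketch of that kind of experiment in PyTorch on a deliberately tiny, made-up classification problem. The 90% pruning fraction, network sizes, and data are all illustrative, and a toy problem this small may not reproduce the effect as starkly as real benchmarks do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up classification data standing in for a real dataset.
X = torch.randn(2000, 20)
y = (X[:, :2].sum(dim=1) > 0).long()

def train(model, steps=500):
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):
        opt.zero_grad()
        loss_fn(model(X), y).backward()
        opt.step()

def accuracy(model):
    return (model(X).argmax(dim=1) == y).float().mean().item()

model = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 2))
train(model)
print("trained:", accuracy(model))

# Step 1: magnitude pruning -- zero out the smallest 90% of each weight matrix.
masks = {}
for name, p in model.named_parameters():
    if p.dim() > 1:
        threshold = p.abs().flatten().kthvalue(int(0.9 * p.numel())).values
        masks[name] = (p.abs() > threshold).float()
        p.data *= masks[name]
print("pruned:", accuracy(model))  # often surprisingly close to the trained accuracy

# Step 2: randomly re-initialize the surviving weights and re-train,
# keeping the pruned connections held at zero.
for name, p in model.named_parameters():
    if name in masks:
        p.data = torch.randn_like(p) * 0.1 * masks[name]
        p.register_hook(lambda g, m=masks[name]: g * m)  # zero gradients on pruned weights
train(model)
print("re-initialized and re-trained:", accuracy(model))  # tends to fall short on real tasks
```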
Old: Models with lots of parameters are likely to overfit.
New: Adding more parameters to a model often makes it generalize better.
Models with lots of parameters tend to overfit, right? Wrong! Adding more parameters to a model can actually make it generalize to new data better than a smaller model.
Most work in classical statistical learning theory focused on the regime in which there are fewer model parameters than samples. As one adds parameters, the model does a better job of capturing the data it's being trained on, and the training set prediction error decreases. The prediction error on a test dataset often decreases too, but only to a point: beyond that point it starts to increase, and it explodes as the number of parameters in the model approaches the number of samples in the training set.
But, what happens if we don’t stop when the number of parameters in the model approaches the number of samples in the training set? What if we just keep adding parameters and explore the regime in which the model has more parameters than we have samples? The prediction error in the training set goes to zero and stays there, because the model can interpolate all of the samples in the training set. But the prediction error on the test dataset starts to come back down! In some cases, it keeps decreasing until the test set performance of a really big model is actually much better than the best small model. This is called “double descent”.
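One common way to see this for yourself is with random-feature regression, where you can sweep the parameter count right past the number of training samples. Here’s a rough sketch; the random Fourier features, sample sizes, and noise level are arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up 1-D regression task: y = sin(3x) + noise.
def make_data(n):
    x = rng.uniform(-1, 1, size=(n, 1))
    return x, np.sin(3 * x[:, 0]) + 0.1 * rng.normal(size=n)

x_train, y_train = make_data(40)    # 40 training samples
x_test, y_test = make_data(2000)

def features(x, W, b):
    # Random Fourier features; only the output weights get fitted.
    return np.cos(x @ W + b)

for n_params in [5, 20, 35, 40, 45, 80, 200, 1000]:
    W = rng.normal(scale=5.0, size=(1, n_params))
    b = rng.uniform(0, 2 * np.pi, size=n_params)
    # Minimum-norm least squares: once n_params >= 40 it interpolates the training set.
    coef, *_ = np.linalg.lstsq(features(x_train, W, b), y_train, rcond=None)
    test_mse = np.mean((features(x_test, W, b) @ coef - y_test) ** 2)
    print(f"{n_params:5d} parameters -> test MSE {test_mse:.3f}")
```

In runs like this, the test error typically spikes as the parameter count approaches the 40 training samples and then falls again as the model keeps growing, though the exact numbers depend on the random draw.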
Here’s a graph (from the same postdoc advisor who first introduced me to deep learning) illustrating double descent.
Jason W. Rocks and Pankaj Mehta: Memorizing without overfitting: Bias, variance, and interpolation in overparameterized models. Phys. Rev. Research 4, 013201 – Published 15 March 2022. https://doi.org/10.1103/PhysRevResearch.4.013201
Of course, the next natural question is: why does double descent happen? That’s an active area of theoretical research, and it’s beyond the scope of this blog post to go through it. What’s important is to come to terms with the fact that double descent is real, that it’s widely observed in empirical studies of many different modeling problems, and that there are a variety of proposed explanations for it.
Big models often generalizing better than smaller ones is interesting on its own, but let me end this section with another observation. Earlier, I mentioned that it’s possible to prune a big model by setting some of its parameters to zero without substantially decreasing its performance. However, as one prunes parameters to make the model smaller, the first place performance degrades is on groups that are underrepresented in the training set (Hooker et al., Characterising Bias in Compressed Models). That is, taking a big model and making it smaller to gain interpretability actually amplifies bias!
Old: It’s important to train a model on data relevant to its intended use.
New: It’s better to train a model to do many different tasks than any one specific task.
The last thing to go is the entire concept of a specialized model and, with it, most of the intuition one gains from classical statistical learning theory. Consider the following thought experiment.
Imagine you want to train a model to generate prose in the style of Shakespeare. You have to choose between the following two approaches.
Option A: Train a generative language model on the combined text of everything ever written by Shakespeare, but nothing else.
Option B: Train a generative language model on all of the text you can get your hands on as long as it’s not written by Shakespeare, then provide the trained model with one example of something written by Shakespeare.
Which of these is likely to perform better? Probably Option B.
There are lots of research areas in modern machine learning related to this phenomenon, such as model pretraining (i.e., training a model on a large compendium of different tasks before fine-tuning it for a specific task), transfer learning (i.e., training a model on one task, then fine-tuning it on a different task), few-shot learning (i.e., training a model that’s capable of learning to perform new tasks from a few examples), and zero-shot learning (i.e., training a model that’s capable of solving new tasks even if it’s never seen examples of those tasks before). This is the frontier and future of machine learning research.
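To make the pretrain-then-fine-tune pattern concrete, here’s a minimal PyTorch sketch. The data, network sizes, and the “related” pretraining task are all invented for illustration; the point is the structure: a shared backbone trained on lots of broad data, then a fresh head fitted to a handful of examples from the task we actually care about.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def fit(model, X, y, steps=300, lr=0.05):
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.Adam(params, lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):
        opt.zero_grad()
        loss_fn(model(X), y).backward()
        opt.step()

# Shared representation that gets reused across tasks.
backbone = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU())

# 1) Pretrain on a large, broadly related (made-up) task.
X_pre = torch.randn(5000, 32)
y_pre = (X_pre[:, :4].sum(dim=1) > 0).long()
fit(nn.Sequential(backbone, nn.Linear(64, 2)), X_pre, y_pre)

# 2) Fine-tune on the target task with only 20 labeled examples.
X_ft = torch.randn(20, 32)
y_ft = (X_ft[:, :4].sum(dim=1) > 0.5).long()
for p in backbone.parameters():
    p.requires_grad = False                            # keep the pretrained representation
finetuned = nn.Sequential(backbone, nn.Linear(64, 2))  # fresh head for the new task
fit(finetuned, X_ft, y_ft, steps=200)
```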
As a practical example, when Google wanted to train a large language model to encode clinical knowledge, they didn’t simply train a generative language model on medical text; rather, they started with a large language model (called PaLM, with 540B parameters) trained on a diverse collection of text from the internet and books, and then fine-tuned it to improve performance on medical knowledge, creating Med-PaLM (Singhal et al., Large Language Models Encode Clinical Knowledge).
To be honest, I haven’t yet developed a deep intuition for why this works so well. Currently, I’m simply at the stage of acceptance. At a high level, it makes sense that tasks from different domains likely share some common structure.
For example, if we wanted to build a model to predict disease progression in patients with Alzheimer’s disease, it could make sense to also train it to predict disease progression in patients with related diseases like Parkinson’s because there’s a lot of shared biology. But, one can take this logic even further and suggest that it could make sense to train the model to predict the evolution of any system of differential equations because, even if there’s no shared biology, there’s shared mathematical structure in the evolution of dynamical systems. Since these shared structures are buried in complex systems that aren’t intuitive to humans, it’s best to give all of the data to a really big model and let gradient descent do its job.
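As a sketch of what that kind of joint training can look like in practice, here’s a toy multi-task setup in PyTorch with a shared trunk and one prediction head per task. The task names, data, and sizes are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Shared trunk that must capture structure common to all tasks.
trunk = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())

# One lightweight head per task (hypothetical task names).
heads = nn.ModuleDict({
    "alzheimers_progression": nn.Linear(32, 1),
    "parkinsons_progression": nn.Linear(32, 1),
    "synthetic_dynamics": nn.Linear(32, 1),
})

# Made-up regression data for each task.
def toy_task():
    X = torch.randn(500, 16)
    y = X[:, :3].sum(dim=1, keepdim=True) + 0.1 * torch.randn(500, 1)
    return X, y

data = {name: toy_task() for name in heads}

opt = torch.optim.Adam(list(trunk.parameters()) + list(heads.parameters()), lr=1e-3)
loss_fn = nn.MSELoss()

for step in range(1000):
    opt.zero_grad()
    # Sum the losses across tasks so every task shapes the shared trunk.
    loss = sum(loss_fn(heads[name](trunk(X)), y) for name, (X, y) in data.items())
    loss.backward()
    opt.step()
```

Every task’s gradient flows through the same trunk, so whatever structure the tasks share ends up encoded there.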
Even though many of the methods used in modern machine learning rely on concepts that conflict with intuitions from classical statistical learning theory, there is no denying the unreasonable effectiveness of training big neural networks by gradient descent. In order to come to terms with these new results, I’ve had to throw out and replace my old intuition.
Old intuition:
The parameters in a model should be interpretable.
Models with lots of parameters are likely to overfit.
It’s important to train a model on data relevant to its intended use.
New intuition:
Forget trying to interpret the model parameters and just learn to predict.
Adding more parameters to a model often makes it generalize better.
It’s better to train a model to do many different tasks than any one specific task.
Deep learning has seen slower adoption in the medical field than in other areas. I think one reason for that is, perhaps counterintuitively, the prevalence of predictive modeling in medicine. Lots of people work on clinical prediction models, and most of those people cut their teeth in the era of statistical learning theory. They’re still suffering from the same affliction I was: a misguided intuition that hasn’t yet adapted to the last decade of astounding results in machine learning research. It took me years to come to grips with the new reality. So, it’s probably too much to hope that a single blog post could convince anyone to adjust their thinking, but I do hope that the deep learning skeptics in the field at least find these ideas intriguing enough to spend some time researching them. I believe that medicine is the field most likely to benefit from the move to big general models, and it would be a shame if outdated intuition held back progress that could save and improve countless lives.