Transfer Learning is the Most Important Tool You Need to Learn
You almost immediately learn about transfer learning when you take the fast.ai course. This is no mistake! Transfer learning gets you started in your deep learning journey quickly, building your first neural network-based classifier in minimal time.
But transfer learning is so much more, and it's surprisingly important for machine learning in science.
From a technical standpoint, transfer learning means taking a trained machine learning model and using it on a new, different, but related task. Let's say you take a classifier trained on ImageNet, like a ResNet. These days it's effortless to download that ResNet with its ImageNet weights already trained! Training it yourself usually takes days to weeks and a significant chunk of a research grant.
However, our data isn't ImageNet. Let's say you want to build a cat/dog classifier. Transfer learning lets you reuse the weights learned on ImageNet, a far more complex dataset with many more classes, whose complexity is captured in the weights of the network. Machine learning practitioners often retrain the network with a low learning rate, a process called fine-tuning, which preserves that learned complexity while adjusting the weights ever so slightly for the cat/dog task.
Science doesn't really care about cat/dog classifiers. We want to scan brains, find black holes and explore the core of the Earth!
This is where the big sibling of transfer learning comes in: domain adaptation. Domain adaptation posits that an agent trained on a task A can perform task B with minimal retraining if A and B are sufficiently similar. Sound familiar? Yes, it's a general version of transfer learning. The underlying idea is that the agent learns a distribution over task A, just like a neural network learns a representation of, e.g., ImageNet. The agent, or our network, can then be adapted to a new task whose distribution is similar enough that the learned distribution can be "shifted" towards our target task.
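Domain adaptation comes in many flavours. As one concrete example, swapped in here for illustration and not necessarily what any of the work mentioned in this post uses, Deep CORAL "shifts" a network towards a new domain by penalising the difference between the feature covariances of labelled source data and unlabelled target data. This sketch assumes the features are already extracted as `(n_samples, n_features)` tensors; the function names are hypothetical.

```python
import torch


def covariance(features: torch.Tensor) -> torch.Tensor:
    """Unbiased covariance of an (n_samples, n_features) tensor."""
    centered = features - features.mean(dim=0, keepdim=True)
    return centered.T @ centered / (features.shape[0] - 1)


def coral_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Deep CORAL loss: ||C_source - C_target||_F^2 / (4 d^2), d = feature dimension.

    `source` holds features of labelled source-domain samples, `target` features
    of unlabelled target-domain samples; both have shape (n_samples, d).
    """
    d = source.shape[1]
    diff = covariance(source) - covariance(target)
    return (diff ** 2).sum() / (4 * d * d)
```

During training you would add `weight * coral_loss(source_features, target_features)` to the ordinary task loss on the labelled source batch, with `weight` a hyperparameter you tune for your problem.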
What does that mean?
How can this help me in science?
ImageNet weights are readily available and capture a lot of complexity from a classification task. Larger datasets exist, of course, but what matters here is availability. If you're more interested in other tasks, like object detection, PASCAL VOC or COCO are good choices.
Then you can take these pre-trained networks and see whether you can fine-tune them on your data. In fact, pre-trained ImageNet ResNets were used in the Kaggle competition to segment salt in seismic data! And yes, my 2018 conference paper was a direct precursor to this work.
Be aware that some networks generalize better than others. VGG16 is known for its general-purpose convolutional filters, whereas a ResNet can be tricky, depending on how much data you have available.
Remember that you can always use a pre-trained network as part of a larger architecture! In the competition above, people embedded the pre-trained ResNet in a U-Net to perform semantic segmentation instead of classic single-image classification!
In science, especially in emerging applications, researchers often publish a proof-of-concept paper. Making a network work on one specific use case can be a valuable gateway to a general method.
Great scientists publish their trained models alongside the code to reproduce their results. Frequently, you can take such a trained network and fine-tune it on a new site with a similar problem. One example from the meeting of the Focus Group on AI for Natural Disaster Management I attended today: train a network on an avalanche-prone site where a lot of data is available, then transfer the complexity it captured at that first site to a new site.
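When transferring a published model to a new site, one option, borrowed from fast.ai's idea of discriminative learning rates, is to give the pre-trained body a much smaller learning rate than the task head. After loading the authors' released weights (typically `model.load_state_dict(torch.load(path, map_location="cpu"))`, with `path` pointing at their checkpoint), an optimizer with two parameter groups could look like this. The helper name, the head prefix `"fc"` and both learning rates are illustrative assumptions.

```python
import torch
from torch import nn


def site_transfer_optimizer(model: nn.Module, head_prefix: str = "fc",
                            body_lr: float = 1e-5, head_lr: float = 1e-3) -> torch.optim.Optimizer:
    """Adam with a small learning rate for the pre-trained body and a larger one for the head."""
    body, head = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # skip anything deliberately frozen
        (head if name.startswith(head_prefix + ".") else body).append(param)
    return torch.optim.Adam([
        {"params": body, "lr": body_lr},  # gently adapt what was learned at the first site
        {"params": head, "lr": head_lr},  # let the task head move faster
    ])
```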
This is not the generalization that most machine learning scientists strive for. However, it can be a valuable use case for applied scientists in specific areas. It's also great as a proof of concept to obtain grants for more data collection.
The holy grail. The network to rule them all. The GPT-3 to my little corner of science.
When you're in the lucky position of having a massive, diverse collection of datasets, you can train a model on all of it. Ideally, this model captures general abstractions from the data. Then, you fine-tune it to specific areas when needed.
This is done, for example, in active seismic. Many data brokers and service companies build large-scale models that can then be fine-tuned to whatever segmentation task a customer has at hand.
These models are particularly exciting to examine in ablation and explainability studies, which give deeper insight into specific domains and into how a neural network interprets them.
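Two of the simplest tools for such studies, sketched in PyTorch under the assumption of a standard image classifier, are a vanilla gradient saliency map (how strongly each input pixel affects one class score) and a channel ablation (zero one feature map via a forward hook and see how the output changes). Both function names are hypothetical, and serious explainability work usually goes well beyond these baselines.

```python
import torch
from torch import nn


def gradient_saliency(model: nn.Module, image: torch.Tensor, target_class: int) -> torch.Tensor:
    """Vanilla gradient saliency for one (C, H, W) image; returns an (H, W) map."""
    model.eval()  # note: switches the model to inference mode (BatchNorm, dropout)
    x = image.unsqueeze(0).clone().requires_grad_(True)
    score = model(x)[0, target_class]
    (grad,) = torch.autograd.grad(score, x)  # d score / d input; parameter .grad untouched
    return grad.abs().amax(dim=1)[0]  # strongest response across colour channels


def ablate_channel(model: nn.Module, layer: nn.Module, channel: int, x: torch.Tensor) -> torch.Tensor:
    """Run `model` on `x` with one output channel of `layer` set to zero."""
    def zero_channel(module, inputs, output):
        output = output.clone()
        output[:, channel] = 0
        return output  # a forward hook's return value replaces the layer output

    handle = layer.register_forward_hook(zero_channel)
    try:
        with torch.no_grad():
            return model(x)
    finally:
        handle.remove()  # restore the unmodified model
```

Comparing `ablate_channel(...)` with the unmodified output on the same batch shows how much a single filter contributes to the prediction.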
You could start with the fast.ai course and learn about transfer learning, or read some research papers on the fantastic Papers with Code.
Or maybe you'd like to watch the CVPR 2021 workshop that in part inspired me to write this short look into applications of transfer learning: Data- and Label-efficient Learning in an Imperfect World.