Model Training Patterns - Transfer Learning

Transfer learning, in short, is incorporating a previously trained model, with its weights frozen (i.e., non-trainable) and its final layer removed, into a new model that solves a similar but more specialized problem. The final layer is replaced with an output layer for our specialized task before training continues.
Some common use cases where transfer learning can be useful include:
  • Image object detection
  • Image style transfer
  • Image generation
  • Text classification
  • Machine translation
    Bottleneck layer

    The bottleneck layer is the layer that represents the input in its lowest-dimensionality space.

    Typically, it's the last layer before a flattening operation.
    This is the last layer of the pre-trained model that we want to load and attach the new custom layers to.
    Following is how we would use a VGG model with TensorFlow for transfer learning on specialized image data where each example has the shape 150 x 150 x 3:
    import tensorflow as tf

    vgg_model = tf.keras.applications.VGG19(
      # last layer to be loaded is the bottleneck layer
      include_top=False,
      weights='imagenet',
      input_shape=(150, 150, 3)
    )
    
    # Freeze the pre-trained weights so the pre-trained model has
    # 0 trainable parameters: this is feature extraction
    vgg_model.trainable = False
    
    # image_batch is a batch of input images of shape (batch_size, 150, 150, 3)
    feature_batch = vgg_model(image_batch)
    
    global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
    feature_batch_avg = global_avg_layer(feature_batch)
    
    # new output layer for the specialized task (here, 8 classes)
    prediction_layer = tf.keras.layers.Dense(8, activation='softmax')
    prediction_batch = prediction_layer(feature_batch_avg)
    
    specialized_model = tf.keras.Sequential([
      vgg_model,
      global_avg_layer,
      prediction_layer
    ])
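    As a minimal sketch (assuming a hypothetical tf.data.Dataset named train_dataset that yields batches of 150 x 150 x 3 images with integer labels for the 8 classes), the specialized model could then be compiled and trained like this:
    specialized_model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      # integer labels; use 'categorical_crossentropy' for one-hot labels
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy']
    )
    specialized_model.fit(train_dataset, epochs=5)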
    The embedding analogy
    The bottleneck layer in a pre-trained model can be compared to an embedding.
    Consider an encoder-decoder architecture. The bottleneck layer acts as an embedding that represents the original data in a lower-dimensional space, which the decoder, i.e., the new custom network attached after the bottleneck layer, then decodes back into the original data.
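    To make the analogy concrete, here is a minimal, hypothetical autoencoder sketch (the 784- and 32-unit sizes are illustrative assumptions, not from the text): the encoder compresses the input into a 32-dimensional bottleneck, i.e., an embedding, and the decoder reconstructs the original input from that embedding.
    encoder = tf.keras.Sequential([
      tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
      # bottleneck layer: the embedding of the input
      tf.keras.layers.Dense(32, activation='relu')
    ])
    decoder = tf.keras.Sequential([
      tf.keras.layers.Dense(128, activation='relu', input_shape=(32,)),
      # reconstruct the original 784-dimensional input
      tf.keras.layers.Dense(784, activation='sigmoid')
    ])
    autoencoder = tf.keras.Sequential([encoder, decoder])
    autoencoder.compile(optimizer='adam', loss='mse')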
    Transfer learning with TF Hub
    import tensorflow_hub as hub

    hub_layer = hub.KerasLayer(
      "url_to_model",
      input_shape=[],
      dtype=tf.string,
      # trainable=True enables fine-tuning of the pre-trained weights
      trainable=True
    )

    model = tf.keras.Sequential([
      hub_layer,
      tf.keras.layers.Dense(32, activation='relu'),
      tf.keras.layers.Dense(1, activation='sigmoid')
    ])
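    As with the VGG example, this model would then be compiled and trained; a minimal sketch, assuming a hypothetical text_dataset of (string, binary label) batches that matches the single sigmoid output:
    model.compile(
      optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy']
    )
    model.fit(text_dataset, epochs=3)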
    To summarize, transfer learning involves the following:
  • Learning new things (the combination of the pre-trained model and the network after the bottleneck layer)
  • Learning how to learn new things (with the pre-trained model)
    What happens when we modify the weights of the pre-trained model?
    In the first example above, we set vgg_model.trainable to False. This is feature extraction: only the custom layers after the bottleneck layer are trained.
    In the second example, however, we set trainable to True. This is considered fine-tuning the weights of the pre-trained model. We can update the weights of all layers or of only some layers of the pre-trained model.
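    As a minimal sketch of the partial case (reusing vgg_model and specialized_model from the first example), we can unfreeze the backbone and then re-freeze all but its last few layers; the choice of 4 layers is an illustrative assumption:
    vgg_model.trainable = True
    # re-freeze everything except the last 4 layers of the pre-trained model
    for layer in vgg_model.layers[:-4]:
      layer.trainable = False

    # re-compile so the new trainable settings take effect, keeping the
    # learning rate low so the pre-trained weights change slowly
    specialized_model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy']
    )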
    Decide the number of layers to fine-tune
    To decide the number of layers that should be fine-tuned, we can use the following approach:
  • Keep the learning rate low (~0.001)
  • Keep the number of iterations small
  • Unfreeze layers from the end of the pre-trained model and monitor the model's loss after training
  • Iteratively unfreeze more layers, continuing to monitor the loss after each training run
  • Stop when the first layer is reached or the model's loss plateaus

    This approach is considered to be progressive fine-tuning; a sketch is shown below.
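    The following is a rough sketch of this procedure, reusing vgg_model, specialized_model, and the hypothetical train_dataset from the earlier example:
    num_layers = len(vgg_model.layers)
    for num_unfrozen in range(1, num_layers + 1):
      # unfreeze from the end: keep all but the last num_unfrozen layers frozen
      vgg_model.trainable = True
      for layer in vgg_model.layers[:num_layers - num_unfrozen]:
        layer.trainable = False
      # keep the learning rate low and the number of iterations small
      specialized_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
      )
      history = specialized_model.fit(train_dataset, epochs=1)
      print('unfrozen layers:', num_unfrozen, 'loss:', history.history['loss'][-1])
      # stop (break) early once the loss plateaus or the first layer is reached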
    Applying Feature extraction vs. Fine-tuning

    Feature extraction:
  • Smaller dataset (100 to 1,000 examples)
  • Prediction task is different from that of the pre-trained model
  • Preferred when the budget is low

    Fine-tuning:
  • Larger dataset (~1 million examples)
  • Prediction task is the same as or similar to that of the pre-trained model
  • Preferred when the budget is high, as training time and computation cost are higher when updating the weights of the pre-trained model
