Model Training Patterns - Checkpoints

When training on larger datasets, it becomes extremely important to save partially trained models, because of the long training times involved and the chance of machine failure. Ideally, we do this at the end of every epoch.

However, it's not just the model that we need to save at the end of each epoch, but also the state of the model at that point in time.

Model and model's state

A model contains the information that's necessary to create the prediction function.

A model's state also captures information about the point in time at which the model was saved, such as the epoch, batch number, learning rate, etc.

Saving the full model state so that training can be resumed later is called checkpointing, and the saved model files are called checkpoints.
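To make the distinction concrete, here is a minimal sketch in plain Python (no ML framework). The "model" is reduced to a list of weights, while the checkpoint also stores the training state needed to resume. The function names and the JSON layout are hypothetical; Keras and other frameworks use their own checkpoint formats, which also include optimizer state.

```python
import json
import os
import tempfile


def save_checkpoint(path, weights, epoch, batch, learning_rate):
    """Write the model (weights) plus its training state to a JSON file."""
    state = {
        "weights": weights,              # the model: enough to make predictions
        "epoch": epoch,                  # training state: where training stopped
        "batch": batch,
        "learning_rate": learning_rate,
    }
    with open(path, "w") as f:
        json.dump(state, f)


def load_checkpoint(path):
    """Read a checkpoint back so training can resume from the saved point."""
    with open(path) as f:
        return json.load(f)


ckpt_dir = tempfile.mkdtemp()
path = os.path.join(ckpt_dir, "ckpt_epoch_03.json")
save_checkpoint(path, weights=[0.5, -1.25], epoch=3, batch=0, learning_rate=0.01)

state = load_checkpoint(path)
start_epoch = state["epoch"] + 1  # resume with the epoch after the saved one
print(state["weights"], start_epoch, state["learning_rate"])
```

Saving only the weights would let us make predictions, but without the epoch and learning rate we could not pick up training exactly where it stopped.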

In Keras, a model can be checkpointed as follows:

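import tensorflow as tf

# Include the epoch in the file name so that each epoch's checkpoint is kept
# instead of being overwritten. OUTDIR is the output directory. With Keras 3,
# full-model checkpoints must end in ".keras", e.g. 'taxi.{epoch:02d}.keras'.
checkpoint_path = '{}/checkpoints/taxi.{{epoch:02d}}'.format(OUTDIR)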
cp_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_path,
                save_weights_only=False,
                verbose=1)

history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=3,
                    validation_data=(x_val, y_val),
                    verbose=2,
                    callbacks=[cp_callback])

With this callback, training proceeds as fit, evaluate, checkpoint; fit, evaluate, checkpoint; and so on, once per epoch.
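The fit, evaluate, checkpoint cycle can also be sketched without a framework. This toy example learns a single weight w in y = w * x by gradient descent, measures the validation loss after each epoch, and keeps one checkpoint per epoch. The data, learning rate, and function name are illustrative assumptions; in the Keras code above, ModelCheckpoint performs the checkpoint step.

```python
def train_with_checkpoints(train, val, epochs, lr=0.05):
    """Run epochs of gradient descent on y = w * x, checkpointing after each one."""
    w = 0.0
    checkpoints = []
    for epoch in range(1, epochs + 1):
        # Fit: one pass of per-example gradient descent on squared error.
        for x, y in train:
            grad = 2 * (w * x - y) * x
            w -= lr * grad
        # Evaluate: mean squared error on the validation set.
        val_loss = sum((w * x - y) ** 2 for x, y in val) / len(val)
        # Checkpoint: record the model and its state at the end of the epoch.
        checkpoints.append({"epoch": epoch, "w": w, "val_loss": val_loss})
    return checkpoints


train = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
val = [(4.0, 8.0)]
for ckpt in train_with_checkpoints(train, val, epochs=3):
    print(ckpt)
```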

Partially trained models saved through checkpoints are useful in a few other ways:

Generalizable

Partially trained models often generalize better than the fully trained model. Because the checkpoints are kept, we can go back and pick the ones that have learned a high-level view of the data rather than the ones that are overfitting it.
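Choosing among checkpoints after the fact can be as simple as picking the one with the lowest validation loss. A minimal sketch, assuming each checkpoint records its validation loss (the function name and the numbers below are invented for illustration):

```python
def select_best_checkpoint(checkpoints):
    """Return the checkpoint whose validation loss is lowest."""
    return min(checkpoints, key=lambda c: c["val_loss"])


history = [
    {"epoch": 1, "val_loss": 0.90},
    {"epoch": 2, "val_loss": 0.55},
    {"epoch": 3, "val_loss": 0.48},
    {"epoch": 4, "val_loss": 0.51},  # validation loss rising: overfitting begins
    {"epoch": 5, "val_loss": 0.60},
]
best = select_best_checkpoint(history)
print(best["epoch"])  # 3
```

Epochs 4 and 5 were trained longer but fit the validation data worse, so the epoch 3 checkpoint is the one we would keep.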

Early stopping

When a model starts to overfit, its validation error begins to increase. To avoid this, we can check the validation error at the end of each epoch and decide whether to continue training or stop early.

In general, though, it is worth training a little longer before deciding to stop. The rise in validation error may happen because the model is learning rare situations in the dataset, and the error may fall again after a while. So it is usually better to wait a few epochs before stopping.
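The "wait before stopping" advice is what a patience setting captures; Keras exposes it as the patience argument of tf.keras.callbacks.EarlyStopping. Below is a minimal framework-free sketch of that logic, assuming one validation loss per epoch; the function name and loss values are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 1-based.

    Training stops once the validation loss has failed to beat the best value
    for `patience` consecutive epochs; otherwise it runs to the last epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch


losses = [0.90, 0.60, 0.65, 0.58, 0.57, 0.59, 0.61, 0.62]
print(early_stopping_epoch(losses, patience=1))  # (3, 2)
print(early_stopping_epoch(losses, patience=3))  # (8, 5)
```

With patience=1, training stops at epoch 3 and misses the later improvement at epochs 4 and 5; with patience=3 it gets past the temporary bump and finds the best epoch, 5.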

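Checking for early stopping at the end of every batch, rather than only at the end of every epoch, can be beneficial too, since it gives finer-grained control at the cost of more frequent evaluation.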

Regularization

An alternative to early stopping and checkpoint selection is regularization.

Instead of splitting the dataset three ways (train, validation and test), we split it just two ways (train and evaluation). This leaves more of the data available for training.

The evaluation data serves as test data during the experimentation phase and as validation data in production.

During the experimentation phase, we tune the regularization based on how the training error compares with the test (evaluation) error.

In production, we train the model with early stopping or checkpoint selection, using the evaluation data to monitor how well it performs.

Fine-tuning

Checkpoints are also useful when fine-tuning a model on freshly arrived data. Rather than retraining from scratch, we can resume from a checkpoint saved a few steps before the training loss starts to plateau and continue training on the new data.
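One way to locate "a few steps before the plateau" is to treat the plateau as starting at the first checkpoint whose improvement in training loss falls below a threshold, and then step back a fixed number of checkpoints. This rule, the threshold, the function name, and the loss values are all assumptions made for the sketch, not a prescribed method.

```python
def checkpoint_before_plateau(train_losses, min_improvement, steps_back):
    """Index of the checkpoint `steps_back` before the training loss plateaus.

    The plateau is taken to begin at the first checkpoint whose improvement
    over the previous one is smaller than `min_improvement`.
    """
    plateau_start = len(train_losses) - 1  # no plateau found: use the last one
    for i in range(1, len(train_losses)):
        if train_losses[i - 1] - train_losses[i] < min_improvement:
            plateau_start = i
            break
    return max(plateau_start - steps_back, 0)


losses = [2.0, 1.2, 0.8, 0.6, 0.55, 0.54, 0.535]
print(checkpoint_before_plateau(losses, min_improvement=0.1, steps_back=2))  # 2
```

Here the improvement first drops below 0.1 at checkpoint 4 (0.6 to 0.55), so fine-tuning on new data would start from checkpoint 2.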

Using steps_per_epoch for checkpointing

Epoch: A full training pass over the entire dataset such that each example has been seen once.

1 epoch = N / batch_size iterations (or steps), where N is the number of training examples

So, by default, an epoch lasts N / batch_size steps, i.e., one full pass over the data. If we want finer control over when checkpoints are written, we can set steps_per_epoch explicitly.
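As a quick check of the formula: when N is not a multiple of batch_size, the last batch is partial, so a full pass takes N / batch_size rounded up (unless the input pipeline drops the remainder, e.g. with drop_remainder=True in tf.data's Dataset.batch). A small helper, with a hypothetical name:

```python
def default_steps_per_epoch(num_examples, batch_size):
    """Steps in one full pass over the data; the last batch may be partial."""
    # Integer ceiling division avoids floating-point error for large N.
    return (num_examples + batch_size - 1) // batch_size


print(default_steps_per_epoch(15_000_000, 100))  # 150000
print(default_steps_per_epoch(1_050, 100))       # 11 (the last batch has 50 examples)
```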

Suppose we have 15 million examples and, over time, we have found that the model converges after 14.3 epochs rather than 15. We would like to save the resources needed to process the remaining 0.7 epochs (0.7 × 15 million = 10.5 million examples) while still getting 15 checkpoints. How can we do that?

If we were to train for 15 epochs with a batch size of 100, checkpointing at the end of each epoch (i.e., 15 checkpoints), each epoch would take 15,000,000 / 100 = 150,000 steps.

To train for only 14.3 epochs while keeping 15 checkpoints, we instead need 150,000 × (14.3 / 15) = 143,000 steps_per_epoch. Here, 14.3 is the stop point: the number of full passes over the data after which training stops.

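When newer data arrives, we keep training with steps_per_epoch set to 143,000. We update the number of training examples and the stop point, but keep the batch size and the number of checkpoints constant, so that steps_per_epoch stays the same. Since steps_per_epoch depends on the product N × stop point, the stop point shrinks as the dataset grows, and the total number of examples shown to the model is unchanged.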

Holding steps_per_epoch constant in this way defines a virtual epoch: each "epoch" is a fixed number of training steps rather than a full pass over the data.
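The virtual-epoch bookkeeping can be sketched by fixing the total number of examples shown to the model (N × stop point) rather than the stop point itself. The helper below is hypothetical, and the 20 million figure is an invented example of the dataset growing; it shows steps_per_epoch staying at 143,000 while the stop point shrinks.

```python
def virtual_epoch_plan(num_examples, total_examples, batch_size, num_checkpoints):
    """Return (steps_per_epoch, stop_point) for a fixed training budget.

    total_examples = num_examples * stop_point is held fixed, so steps_per_epoch
    does not change as the dataset grows; only the stop point does.
    """
    steps_per_epoch = total_examples // (batch_size * num_checkpoints)
    stop_point = total_examples / num_examples  # full passes over the data
    return steps_per_epoch, stop_point


TOTAL = 214_500_000  # 14.3 passes over the original 15 million examples
for n in (15_000_000, 20_000_000):
    steps, stop = virtual_epoch_plan(n, TOTAL, batch_size=100, num_checkpoints=15)
    print(n, steps, round(stop, 3))
```

Keeping the total number of examples fixed, rather than the stop point, is what makes the checkpoints from before and after the data update comparable: each one still marks the same amount of training.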

Here is how we do this:

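import tensorflow as tf

NUM_TRAINING_EXAMPLES = 15_000_000   # size of the training dataset
STOP_POINT = 14.3                    # full passes over the data before stopping
TOTAL_TRAINING_EXAMPLES = int(NUM_TRAINING_EXAMPLES * STOP_POINT)
BATCH_SIZE = 100
NUM_CHECKPOINTS = 15                 # number of virtual epochs, one checkpoint each
# Same as (NUM_TRAINING_EXAMPLES / BATCH_SIZE) * (STOP_POINT / NUM_CHECKPOINTS),
# but computed with integers so the result is exactly 143,000.
steps_per_epoch = TOTAL_TRAINING_EXAMPLES // (BATCH_SIZE * NUM_CHECKPOINTS)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_path,
                save_weights_only=False,
                verbose=1)

# trainds is assumed to be a tf.data.Dataset that is already batched with
# BATCH_SIZE and repeated (e.g. .batch(BATCH_SIZE).repeat()), since 15 virtual
# epochs cover 14.3 passes over the data. batch_size is not passed to fit()
# because batching is done by the dataset.
history = model.fit(trainds,
                    epochs=NUM_CHECKPOINTS,
                    validation_data=evalds,
                    steps_per_epoch=steps_per_epoch,
                    verbose=2,
                    callbacks=[cp_callback])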
