A Guide to TensorFlow Callbacks

TensorFlow callbacks are an essential part of training deep learning models, providing a high degree of control over many aspects of your model training.

By Keshav Aggarwal

If you are building deep learning models, you may need to sit for hours (or even days) before you can see any real results. You may need to stop model training to change the learning rate, push training logs to the database for future use, or show the training progress in TensorBoard. It seems that we may need to do a lot of work to achieve these basic tasks—that's where TensorFlow callbacks come into the picture.

In this article, we'll cover the details, usage, and examples of TensorFlow callbacks. The outline of this article is as follows:

  • What's a callback function?
  • When callbacks are triggered
  • Available callbacks in TensorFlow 2.0
  • Conclusion

You can also run the full code on the ML Showcase.

What's a Callback Function?

Simply put, callbacks are special utilities or functions that are executed at given stages of the training procedure. Callbacks can help you prevent overfitting, visualize training progress, debug your code, save checkpoints, generate logs, feed TensorBoard, and more. There are many callbacks readily available in TensorFlow, and you can use several at once. We will take a look at the different callbacks available, along with examples of their use.

When a Callback is Triggered

Callbacks are called when a certain event is triggered. There are a few types of events during training that can lead to the trigger of a callback, such as:
on_epoch_begin: as the name suggests, this event is triggered when a new epoch starts.
on_epoch_end: this is triggered when an epoch ends.
on_batch_begin: this is triggered when a new batch is passed for training.
on_batch_end: this is triggered when training on a batch finishes.
on_train_begin: this is triggered when training starts.
on_train_end: this is triggered when training ends.
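
To make these events concrete, here is a minimal sketch of a custom callback (the class name is our own; only the hook names come from the TensorFlow API) that prints a message at each event:

import tensorflow as tf

class TraceCallback(tf.keras.callbacks.Callback):
    # Each method below corresponds to one of the events listed above.
    def on_train_begin(self, logs=None):
        print("Training started")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Epoch {epoch} started")

    def on_batch_begin(self, batch, logs=None):
        print(f"Batch {batch} started")

    def on_batch_end(self, batch, logs=None):
        print(f"Batch {batch} finished")

    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch} ended")

    def on_train_end(self, logs=None):
        print("Training finished")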

To use any callback during model training, you just need to pass the callback objects to the callbacks argument in the model.fit call, for example:

model.fit(x, y, callbacks=list_of_callbacks)

Available Callbacks in TensorFlow 2.0

Let’s take a look at the callbacks which are available under the tf.keras.callbacks module.

1. EarlyStopping

This callback is used very often. It allows us to monitor a metric and stop training when the metric stops improving. For example, assume that you want to stop training if the accuracy is not improving by at least 0.05; this callback will do that for you. This is useful in preventing overfitting of a model, to some extent.

tf.keras.callbacks.EarlyStopping(monitor='val_loss', 
                                min_delta=0, 
                                patience=0, 
                                verbose=0, 
                                mode='auto', 
                                baseline=None, 
                                restore_best_weights=False)

monitor: the name of the metric to monitor.
min_delta: the minimum change in the monitored metric that counts as an improvement.
patience: the number of epochs with no improvement after which training will be stopped.
verbose: whether or not to print additional logs.
mode: defines whether the monitored metric should be increasing ('max'), decreasing ('min'), or inferred from its name ('auto').
baseline: a baseline value for the monitored metric; training is stopped if the model doesn't show improvement over it.
restore_best_weights: if set to True, the model will keep the weights of the epoch with the best value of the monitored metric; otherwise, it will keep the weights of the last epoch.

The EarlyStopping callback is executed via the on_epoch_end trigger for training.
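
A minimal sketch of its use, assuming a compiled model and training data (x_train, y_train, x_val, y_val) already exist:

# Stop when val_loss hasn't improved by at least 0.001 for three
# consecutive epochs, and roll back to the best weights seen so far.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                              min_delta=0.001,
                                              patience=3,
                                              restore_best_weights=True)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=50,
          callbacks=[early_stop])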

2. ModelCheckpoint

This callback allows us to save the model regularly during training. This is especially useful when training deep learning models which take a long time to train. This callback monitors the training and saves model checkpoints at regular intervals, based on the metrics.

tf.keras.callbacks.ModelCheckpoint(filepath, 
                                     monitor='val_loss', 
                                     verbose=0, 
                                     save_best_only=False,
                                     save_weights_only=False, 
                                     mode='auto', 
                                     save_freq='epoch')

filepath: the path for saving the model. You can pass a file path with formatting options like model-{epoch:02d}-{val_loss:0.2f}; the placeholders are filled in with the corresponding values when the model is saved.
monitor: the name of the metric to monitor.
save_best_only: if True, the best model (according to the monitored metric) will not be overwritten by a worse one.
mode: defines whether the monitored metric should be increasing ('max'), decreasing ('min'), or inferred from its name ('auto').
save_weights_only: if True, only the weights of the model will be saved; otherwise the full model will be saved.
save_freq: if 'epoch', the model will be saved after every epoch. If an integer value is passed, the model will be saved after that many batches (not to be confused with epochs).

The ModelCheckpoint callback is executed via the on_epoch_end trigger of training.
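
A minimal sketch, again assuming a compiled model and training data; the filename pattern is an arbitrary choice:

# Save only the best model (by validation loss), embedding the epoch
# number and val_loss in the filename.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='model-{epoch:02d}-{val_loss:0.2f}.h5',
    monitor='val_loss',
    save_best_only=True)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=50,
          callbacks=[checkpoint])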

3. TensorBoard

This is one of the best callbacks if you want to visualize the training summary for your model. This callback generates the logs for TensorBoard, which you can later launch to visualize the progress of your training. We will cover the details for TensorBoard in a separate article.

tf.keras.callbacks.TensorBoard(log_dir='logs',
                               histogram_freq=0,
                               write_graph=True,
                               write_images=False,
                               update_freq='epoch',
                               profile_batch=2,
                               embeddings_freq=0,
                               embeddings_metadata=None,
                               **kwargs)

For now we will look at only one parameter, log_dir, which is the path of the folder where the logs are stored. To launch TensorBoard you need to execute the following command:

tensorboard --logdir=path_to_your_logs

You can launch TensorBoard before or after starting your training.
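
Attaching the callback itself is a one-liner; here is a minimal sketch (the logs folder name is an arbitrary choice, and a compiled model and training data are assumed to exist):

tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir='logs')
model.fit(x_train, y_train, epochs=10, callbacks=[tensorboard_cb])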

The TensorBoard callback is also triggered at on_epoch_end.

4. LearningRateScheduler

This callback is handy in scenarios where the user wants to update the learning rate as training progresses. For instance, as the training progresses you may want to decrease the learning rate after a certain number of epochs. The LearningRateScheduler will let you do exactly that.

tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)

schedule: this is a function that takes the epoch index and returns a new learning rate.
verbose: whether or not to print additional logs.

Below is an example of how to reduce the learning rate after three epochs.

Here is a minimal sketch of a function to pass to the schedule parameter of the LearningRateScheduler callback (the 0.002 starting rate is an illustrative assumption):
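
def scheduler(epoch):
    # Epochs are 0-indexed: keep the initial rate of 0.002 for the
    # first three epochs, then use a rate ten times smaller.
    if epoch < 3:
        return 0.002
    return 0.0002

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)
model.fit(x_train, y_train, epochs=10, callbacks=[lr_scheduler])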

With verbose set to 1, Keras reports the learning rate at every epoch, so you can keep tabs on it: here it stays at 0.002 for the first three epochs and then drops to 0.0002 from the fourth epoch onwards.

This callback is also triggered at on_epoch_end.

5. CSVLogger

As the name suggests, this callback logs the training details in a CSV file. The logged fields are epoch, the training metrics (such as accuracy and loss), and their validation counterparts (val_accuracy and val_loss). Keep in mind that the logged metrics depend on what you pass to model.compile, so pass accuracy as a metric there if you want it to appear in the file.

tf.keras.callbacks.CSVLogger(filename, 
                             separator=',', 
                             append=False)

The logger accepts filename, separator, and append as parameters, where append defines whether to append to an existing file or overwrite it with a new one.

The CSVLogger callback is executed via the on_epoch_end trigger of training. So when an epoch ends, the logs are put into a file.
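
A minimal sketch (the filename is an arbitrary choice):

csv_logger = tf.keras.callbacks.CSVLogger('training.csv',
                                          separator=',',
                                          append=False)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[csv_logger])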

6. LambdaCallback

This callback is required when you need to call some custom function on any of the events, and the provided callbacks do not suffice. For instance, say you want to put your logs into a database.

tf.keras.callbacks.LambdaCallback(on_epoch_begin=None, 
                                  on_epoch_end=None, 
                                  on_batch_begin=None, 
                                  on_batch_end=None,    
                                  on_train_begin=None, 
                                  on_train_end=None, 
                                  **kwargs)

All the parameters of this callback expect a function which takes the arguments specified here:
on_epoch_begin and on_epoch_end: epoch, logs
on_batch_begin and on_batch_end: batch, logs
on_train_begin and on_train_end: logs

Let’s see an example:

Here is a minimal sketch of a function that writes the logs to a file at the end of each batch (the filename is an arbitrary choice):
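
def write_logs(batch, logs):
    # Append the batch index and its metrics (loss, accuracy, ...) to a file.
    with open('batch_logs.txt', 'a') as f:
        f.write(f'batch {batch}: {logs}\n')

log_callback = tf.keras.callbacks.LambdaCallback(on_batch_end=write_logs)
model.fit(x_train, y_train, epochs=10, callbacks=[log_callback])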

This callback writes the logs to the file after each batch is processed. Each line in the file contains the batch index followed by that batch's metrics.

This callback is called for all the events, and executes the custom functions based on the parameters passed.

7. ReduceLROnPlateau

This callback is used when you want to change the learning rate once the metrics have stopped improving. As opposed to LearningRateScheduler, it reduces the learning rate based on the monitored metric, not the epoch number.

tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 
                                     factor=0.1, 
                                     patience=10, 
                                     verbose=0, 
                                     mode='auto',    
                                     min_delta=0.0001, 
                                     cooldown=0, 
                                     min_lr=0, 
                                     **kwargs)

Many of the parameters are similar to those of the EarlyStopping callback, so let's focus on those that are different.
monitor, patience, verbose, mode, min_delta: these are similar to EarlyStopping.
factor: the factor by which the learning rate should be decreased (new learning rate = old learning rate * factor).
cooldown: the number of epochs to wait before restarting the monitoring of the metrics.
min_lr: the minimum bound for the learning rate (the learning rate can’t go below this).

This callback is also called at the on_epoch_end event.
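
A minimal sketch, halving the learning rate whenever val_loss has not improved for five epochs (the numbers are illustrative):

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                 factor=0.5,
                                                 patience=5,
                                                 min_lr=1e-6)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=50,
          callbacks=[reduce_lr])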

8. RemoteMonitor

This callback is useful when you want to post the logs to an API. This callback can also be mimicked using LambdaCallback.

tf.keras.callbacks.RemoteMonitor(root='http://localhost:9000',                
                                   path='/publish/epoch/end/', 
                                   field='data',
                                   headers=None, 
                                   send_as_json=False)

root: the base URL of the server.
path: the endpoint name/path.
field: the name of the key under which all the logs will be sent.
headers: any additional headers that need to be sent.
send_as_json: if True, the data will be sent as JSON; otherwise the logs are sent inside a form field.

For example:

A minimal sketch, posting the logs to an endpoint on localhost:8000 (the port and the JSON setting are our choices):
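
# Requires the requests package to be installed.
remote_monitor = tf.keras.callbacks.RemoteMonitor(root='http://localhost:8000',
                                                  path='/publish/epoch/end/',
                                                  send_as_json=True)

model.fit(x_train, y_train, epochs=10, callbacks=[remote_monitor])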

To see the callback working, you need an endpoint listening on localhost:8000. Any small web server will do; below is a minimal sketch using Python's built-in http.server (a Node.js server works just as well). Save the code in a file called server.py:
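
from http.server import BaseHTTPRequestHandler, HTTPServer

class LogHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Print whatever RemoteMonitor posts at the end of each epoch.
        length = int(self.headers.get('Content-Length', 0))
        print(self.rfile.read(length).decode())
        self.send_response(200)
        self.end_headers()

HTTPServer(('localhost', 8000), LogHandler).serve_forever()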

Then start the server by running python server.py. At the end of each epoch you will see the logs appear in the server console. If the server is not running, you will receive a warning at the end of the epoch.

This callback is also called at the on_epoch_end event.

9. BaseLogger & History

These two callbacks are automatically applied to all Keras models. The history object is returned by model.fit and contains a dictionary with the loss and metric values recorded for each epoch. Its params attribute holds the parameters used for training (epochs, steps, verbose). If you have a callback that changes the learning rate, the learning rate will also be part of the history object.

For example, after a short training run, model_history.history might look something like this (the numbers are illustrative):
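
# Illustrative values from a hypothetical three-epoch run
{'loss': [0.695, 0.652, 0.621],
 'accuracy': [0.51, 0.56, 0.62],
 'val_loss': [0.701, 0.668, 0.644],
 'val_accuracy': [0.50, 0.54, 0.59]}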

BaseLogger accumulates epoch-level averages of your metrics, so the values you see at the end of an epoch are the metrics averaged over all the batches in that epoch.

10. TerminateOnNaN

This callback terminates the training if the loss becomes NaN.

tf.keras.callbacks.TerminateOnNaN()

Conclusion

You can use any of these callbacks as they suit your needs. It's often good (or even necessary) to use multiple callbacks together, like TensorBoard for monitoring progress, EarlyStopping to prevent overfitting, LearningRateScheduler or ReduceLROnPlateau to adjust the learning rate, and ModelCheckpoint to save your training progress.

Remember that you can run the code for all of the callbacks available in tensorflow.keras for free on Gradient. I hope this helps you in training your model.

Happy Deep Learning.
