If you are building deep learning models, you may need to sit for hours (or even days) before you can see any real results. You may need to stop model training to change the learning rate, push training logs to the database for future use, or show the training progress in TensorBoard. It seems that we may need to do a lot of work to achieve these basic tasks—that's where TensorFlow callbacks come into the picture.
In this article we'll cover the details, usage, and examples of TensorFlow callbacks. The outline of this article is as follows:
- What's a callback function?
- When callbacks are triggered
- Available callbacks in TensorFlow 2.0
- Conclusion
You can also run the full code on the ML Showcase.
What's a Callback Function?
Simply put, callbacks are the special utilities or functions that are executed during training at given stages of the training procedure. Callbacks can help you prevent overfitting, visualize training progress, debug your code, save checkpoints, generate logs, create a TensorBoard, etc. There are many callbacks readily available in TensorFlow, and you can use multiple. We will take a look at the different callbacks available along with examples of their use.
When a Callback is Triggered
Callbacks are called when a certain event is triggered. There are a few types of events during training that can trigger a callback, such as:
- on_epoch_begin: as the name suggests, this event is triggered when a new epoch starts.
- on_epoch_end: this is triggered when an epoch ends.
- on_batch_begin: this is triggered when a new batch is passed for training.
- on_batch_end: this is triggered when a batch is finished with training.
- on_train_begin: this is triggered when the training starts.
- on_train_end: this is triggered when the training ends.
To use any callback in model training, you just need to pass the callback object in the model.fit call, for example:
model.fit(x, y, callbacks=list_of_callbacks)
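To make the order of these events concrete, here is a minimal sketch that subclasses tf.keras.callbacks.Callback and records each event as it fires. The class name EventRecorder, the toy data, and the one-layer model are assumptions made purely for illustration.

```python
import numpy as np
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
    """Hypothetical helper: remembers the order in which events fire."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# Toy regression data (arbitrary shapes, illustration only).
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = EventRecorder()
model.fit(x, y, epochs=2, batch_size=16, verbose=0, callbacks=[recorder])
print(recorder.events)
```

Epoch indices start at 0, so two epochs produce the begin/end pairs for epochs 0 and 1, wrapped by the train begin and train end events.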
Available Callbacks in TensorFlow 2.0
Let’s take a look at the callbacks which are available under the tf.keras.callbacks module.
1. EarlyStopping
This callback is used very often. It lets us monitor a metric and stop model training when that metric stops improving. For example, assume that you want to stop training if the accuracy is not improving by at least 0.05; you can use this callback to do so. This is useful in preventing overfitting of a model, to some extent.
tf.keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0,
mode='auto',
baseline=None,
restore_best_weights=False)
- monitor: the name of the metric we want to monitor.
- min_delta: the minimum change in the monitored metric that counts as an improvement.
- patience: the number of epochs with no improvement to wait before stopping the training.
- verbose: whether or not to print additional logs.
- mode: defines whether the monitored metric should be increasing, decreasing, or inferred from the name; possible values are 'min', 'max', or 'auto'.
- baseline: a baseline value for the monitored metric; training stops if the model does not improve on it.
- restore_best_weights: if set to True, the model will get the weights of the epoch which has the best value for the monitored metric; otherwise, it will get the weights of the last epoch.
The EarlyStopping callback is executed via the on_epoch_end trigger for training.
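A minimal sketch of EarlyStopping in use is shown below. The toy data, the small model, and the specific values for min_delta, patience, and max_epochs are assumptions chosen only to keep the example short; how many epochs actually run depends on the random data.

```python
import numpy as np
import tensorflow as tf

# Toy binary-classification data (arbitrary, for illustration only).
x = np.random.rand(128, 4).astype("float32")
y = np.random.randint(0, 2, size=(128, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Stop when val_loss has not improved by at least 0.001 for 3 epochs in a row,
# and roll back to the weights from the best epoch seen.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=3,
    restore_best_weights=True,
)

max_epochs = 50
history = model.fit(
    x, y,
    validation_split=0.25,  # needed so that val_loss exists
    epochs=max_epochs,
    verbose=0,
    callbacks=[early_stop],
)

epochs_run = len(history.history["val_loss"])
print(f"Training ran for {epochs_run} of {max_epochs} epochs")
```

Note that monitoring val_loss only works if validation data is supplied, here through validation_split.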
2. ModelCheckpoint
This callback allows us to save the model regularly during training. This is especially useful when training deep learning models which take a long time to train. This callback monitors the training and saves model checkpoints at regular intervals, based on the metrics.
tf.keras.callbacks.ModelCheckpoint(filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch')
- filepath: path for saving the model. You can pass the file path with formatting options like model-{epoch:02d}-{val_loss:0.2f}; this saves the model with the mentioned values in the name.
- monitor: name of the metric to monitor.
- save_best_only: if True, a checkpoint is written only when the monitored metric improves, so the best model so far is not overwritten.
- mode: defines whether the monitored metric should be increasing, decreasing, or inferred from the name; possible values are 'min', 'max', or 'auto'.
- save_weights_only: if True, only the weights of the model will be saved. Otherwise the full model will be saved.
- save_freq: if 'epoch', the model will be saved after every epoch. If an integer value is passed, the model will be saved after that number of batches (not to be confused with epochs).
With the default save_freq='epoch', the ModelCheckpoint callback is executed via the on_epoch_end trigger of training.
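The sketch below saves the weights after every epoch into a temporary folder. The toy data and the file name pattern are assumptions; in particular, the .weights.h5 suffix is used here because recent Keras versions expect it when save_weights_only=True, while older versions accept other names.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy regression data (illustration only).
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# The {epoch:02d} placeholder is filled in with the 1-based epoch number.
ckpt_dir = tempfile.mkdtemp()
filepath = os.path.join(ckpt_dir, "model-{epoch:02d}.weights.h5")

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath,
    save_weights_only=True,
    save_freq="epoch",
)

model.fit(x, y, epochs=2, verbose=0, callbacks=[checkpoint])

saved_files = sorted(os.listdir(ckpt_dir))
print(saved_files)
```

Setting save_best_only=True together with a monitored validation metric would instead keep only the checkpoints that improve on the best value seen so far.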
3. TensorBoard
This is one of the best callbacks if you want to visualize the training summary for your model. This callback generates the logs for TensorBoard, which you can later launch to visualize the progress of your training. We will cover the details for TensorBoard in a separate article.
tf.keras.callbacks.TensorBoard(log_dir='logs',
histogram_freq=0,
write_graph=True,
write_images=False,
update_freq='epoch',
profile_batch=2,
embeddings_freq=0,
embeddings_metadata=None,
**kwargs)
For now we will look at only one parameter, log_dir, which is the path of the folder where the logs are stored. To launch TensorBoard, execute the following command:
tensorboard --logdir=path_to_your_logs
You can launch the TensorBoard before or after starting your training.
The TensorBoard callback is also triggered at on_epoch_end.
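As a rough sketch, the callback can be attached like this. The temporary log folder and toy model are assumptions, and profile_batch=0 is set only to switch profiling off and keep the example lightweight.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy regression data (illustration only).
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Folder that TensorBoard will later read from.
log_dir = tempfile.mkdtemp()
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=log_dir, profile_batch=0)

model.fit(x, y, epochs=2, verbose=0, callbacks=[tensorboard_cb])

# Event files are written under log_dir.
log_contents = os.listdir(log_dir)
print(log_dir, log_contents)
```

Pointing the tensorboard command at the printed folder with --logdir then shows the recorded loss curve.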
4. LearningRateScheduler
This callback is handy in scenarios where the user wants to update the learning rate as training progresses. For instance, as the training progresses you may want to decrease the learning rate after a certain number of epochs. The LearningRateScheduler will let you do exactly that.
tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
- schedule: this is a function that takes the epoch index (and, in TensorFlow 2, the current learning rate) and returns a new learning rate.
- verbose: whether or not to print additional logs.
Below is an example of how to reduce the learning rate after three epochs.
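A minimal sketch of such a schedule follows. The function name schedule, the two learning rate values, and the toy model are assumptions for illustration, not necessarily the values used originally.

```python
import numpy as np
import tensorflow as tf


def schedule(epoch, lr):
    """Hypothetical schedule: 0.01 for epochs 0-2, then 0.001 afterwards."""
    if epoch < 3:
        return 0.01
    return 0.001


# Toy regression data (illustration only).
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01), loss="mse")

# verbose=1 prints the learning rate chosen at the start of each epoch.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)

model.fit(x, y, epochs=5, verbose=0, callbacks=[lr_scheduler])

final_lr = float(model.optimizer.learning_rate.numpy())
print("final learning rate:", final_lr)
```

Because epoch indices start at 0, the condition epoch < 3 keeps the initial rate for the first three epochs and switches to the lower rate from the fourth epoch on.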
With such a schedule, the learning rate is reduced from the fourth epoch onward. verbose has been set to 1 to keep tabs on the learning rate.
This callback sets the new learning rate at on_epoch_begin and records it in the logs at on_epoch_end.
5. CSVLogger
As the name suggests, this callback logs the training details in a CSV file. The logged columns are epoch plus the values reported for each epoch, such as accuracy, loss, val_accuracy, and val_loss. One thing to keep in mind is that the accuracy columns are only available if you pass accuracy as a metric while compiling the model.
tf.keras.callbacks.CSVLogger(filename,
separator=',',
append=False)
The logger accepts the filename, separator, and append as parameters. append defines whether to append to an existing file or write a new file instead.
The CSVLogger callback is executed via the on_epoch_end trigger of training. So when an epoch ends, the logs are put into a file.
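Here is a short sketch that writes a CSV log and reads it back. The file name training_log.csv, the toy data, and the model are assumptions for illustration.

```python
import csv
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy binary-classification data (illustration only).
x = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
# accuracy is passed as a metric so that it shows up in the CSV columns.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")
csv_logger = tf.keras.callbacks.CSVLogger(log_path, separator=",", append=False)

model.fit(x, y, validation_split=0.25, epochs=3, verbose=0, callbacks=[csv_logger])

# Read the log back: one header row, then one row per epoch.
with open(log_path, newline="") as f:
    rows = list(csv.reader(f))
header, data_rows = rows[0], rows[1:]
print(header)
print(len(data_rows), "epoch rows")
```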
6. LambdaCallback
This callback is required when you need to call some custom function on any of the events, and the provided callbacks do not suffice. For instance, say you want to put your logs into a database.
tf.keras.callbacks.LambdaCallback(on_epoch_begin=None,
on_epoch_end=None,
on_batch_begin=None,
on_batch_end=None,
on_train_begin=None,
on_train_end=None,
**kwargs)
All the parameters of this callback expect a function which takes the arguments specified here:
- on_epoch_begin and on_epoch_end: epoch, logs
- on_batch_begin and on_batch_end: batch, logs
- on_train_begin and on_train_end: logs
Let’s see an example:
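A minimal sketch, assuming a hypothetical log file named batch_logs.txt and a toy model, could look like this. The line format written to the file is also just an illustrative choice.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy regression data (illustration only): 32 samples, so 4 batches of 8.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

log_path = os.path.join(tempfile.mkdtemp(), "batch_logs.txt")
log_file = open(log_path, "w")

batch_logger = tf.keras.callbacks.LambdaCallback(
    # Called after every batch with the batch index and the current logs.
    on_batch_end=lambda batch, logs: log_file.write(
        f"batch={batch} loss={float(logs['loss']):.4f}\n"
    ),
    # Close the file once training is over.
    on_train_end=lambda logs: log_file.close(),
)

model.fit(x, y, epochs=1, batch_size=8, verbose=0, callbacks=[batch_logger])

with open(log_path) as f:
    lines = f.read().splitlines()
print(len(lines), "lines written")
```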
This callback will put the logs into a file after a batch is processed. The output which you can see in the file is:
This callback is called for all the events, and executes the custom functions based on the parameters passed.
7. ReduceLROnPlateau
This callback is used when you want to change the learning rate once a metric has stopped improving. As opposed to LearningRateScheduler, it reduces the learning rate based on the metric (not the epoch).
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
factor=0.1,
patience=10,
verbose=0,
mode='auto',
min_delta=0.0001,
cooldown=0,
min_lr=0,
**kwargs)
Many of the parameters are similar to the EarlyStopping callback, so let's focus on those that are different.
- monitor, patience, verbose, mode, min_delta: these are similar to EarlyStopping.
- factor: the factor by which the learning rate should be decreased (new learning rate = old learning rate * factor).
- cooldown: the number of epochs to wait after a reduction before restarting the monitoring of the metric.
- min_lr: the minimum bound for the learning rate (the learning rate can’t go below this).
This callback is also called at the on_epoch_end event.
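The following sketch halves the learning rate whenever the training loss stalls. The toy data, the starting rate of 0.1, and the factor, patience, and min_lr values are assumptions for illustration; whether and how often a reduction happens depends on the random data.

```python
import numpy as np
import tensorflow as tf

# Toy regression data (illustration only).
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

initial_lr = 0.1
min_lr = 0.001

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=initial_lr), loss="mse")

# Halve the learning rate when the training loss has not improved by at
# least min_delta for one epoch, but never go below min_lr.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="loss",
    factor=0.5,
    patience=1,
    min_delta=0.0001,
    cooldown=0,
    min_lr=min_lr,
)

model.fit(x, y, epochs=20, verbose=0, callbacks=[reduce_lr])

final_lr = float(model.optimizer.learning_rate.numpy())
print("final learning rate:", final_lr)
```

Monitoring loss avoids the need for validation data here; in practice a validation metric such as val_loss is the more common choice.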
8. RemoteMonitor
This callback is useful when you want to post the logs to an API. This callback can also be mimicked using LambdaCallback.
tf.keras.callbacks.RemoteMonitor(root='http://localhost:9000',
path='/publish/epoch/end/',
field='data',
headers=None,
send_as_json=False)
- root: this is the root URL of the target server.
- path: this is the endpoint name/path.
- field: this is the name of the key which will have all the logs.
- headers: the headers which need to be sent.
- send_as_json: if True, the data will be sent in JSON format.
For example:
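A minimal sketch of the Python side, assuming an endpoint listening on localhost:8000 and a toy model, might look like this. The callback relies on the requests package being installed.

```python
import numpy as np
import tensorflow as tf

# Toy regression data (illustration only).
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Post the epoch logs as JSON to http://localhost:8000/publish/epoch/end/
remote_monitor = tf.keras.callbacks.RemoteMonitor(
    root="http://localhost:8000",
    path="/publish/epoch/end/",
    field="data",
    headers=None,
    send_as_json=True,
)

# If nothing is listening on that port, Keras only emits a warning
# at the end of the epoch and training continues.
model.fit(x, y, epochs=1, verbose=0, callbacks=[remote_monitor])
```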
To see the callback working, you need an endpoint hosted on localhost:8000. You can use Node.js to do this. Save the code in the file server.js:
Then start the server by typing node server.js (you should have node installed). At the end of the epoch you will see the log in the node console. If the server is not running then you will receive a warning at the end of the epoch.
This callback is also called at the on_epoch_end event.
9. BaseLogger & History
These two callbacks are automatically applied to all Keras models. The History object is returned by model.fit, and its history attribute contains a dictionary with the accuracy and loss recorded for each epoch. The params property contains the dictionary with the parameters used for training (epochs, steps, verbose). If you have a callback for changing the learning rate, then the learning rate will also be part of the history object.
BaseLogger accumulates a running average of your metrics over the batches of each epoch. So the metrics you see at the end of the epoch are an average of the metrics over all the batches in that epoch.
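As a small sketch, this is how the History object returned by model.fit can be inspected. The toy data and model are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy binary-classification data (illustration only).
x = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# No callbacks are passed; History is attached automatically.
history = model.fit(x, y, validation_split=0.25, epochs=3, verbose=0)

print(sorted(history.history.keys()))  # metric names recorded per epoch
print(history.history["loss"])         # one loss value per epoch
print(history.params)                  # training parameters such as epochs
print(history.epoch)                   # list of epoch indices
```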
10. TerminateOnNaN
This callback terminates the training if the loss becomes NaN.
tf.keras.callbacks.TerminateOnNaN()
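To see it in action, the sketch below deliberately injects a NaN into the input so the loss becomes NaN on the first batch. The corrupted toy data is an assumption made purely to trigger the callback.

```python
import numpy as np
import tensorflow as tf

# Toy regression data with one deliberately corrupted value.
x = np.random.rand(32, 4).astype("float32")
x[0, 0] = np.nan
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

max_epochs = 5
# batch_size=32 puts the corrupted sample into the very first batch.
history = model.fit(
    x, y,
    epochs=max_epochs,
    batch_size=32,
    verbose=0,
    callbacks=[tf.keras.callbacks.TerminateOnNaN()],
)

epochs_run = len(history.epoch)
print(f"Stopped after {epochs_run} of {max_epochs} epochs")
```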
Conclusion
You can use any of these callbacks as they suit your needs. It's often good (or even necessary) to use multiple callbacks, like TensorBoard for monitoring progress, EarlyStopping or LearningRateScheduler to prevent overfitting, and ModelCheckpoint to save your training progress.
Remember that you can run the code for all of the callbacks available in tensorflow.keras for free on Gradient. I hope this helps you in training your model.
Happy Deep Learning.