Checkpointing in TensorFlow

Follow this guide to learn how to directly monitor and checkpoint your models during the training process!


By Fortune Adekogbe

While working on a machine learning problem some months back, I realized that I needed a better way to save the best version of a model during training. This led me to find a way to save a model when specific conditions are satisfied. In this article, I will be discussing what checkpointing is and how it can be implemented using TensorFlow.


Prerequisites

  • Familiarity with Python programming
  • Understanding of basic machine learning terminology
  • Basic familiarity with TensorFlow


What is Checkpointing?

Checkpointing generally involves saving the state of a system so that it can be restored for use at a later time. In this way, information that would otherwise have been lost is retained and can be reloaded later. Yes, it is quite similar to just saving a file, but the term is more commonly used when the state of a complex computational system is being saved.

In machine learning, checkpointing involves saving the current state of a model (its architecture, weights, optimizer state and so on) so that it can be reloaded for use later on.

Why Checkpointing?

  • Interrupted Training Loops: When a training run is terminated, intentionally or unintentionally, all of its progress can be lost, which is costly in both compute and development time. Checkpointing stores the model so that training can be resumed later.
  • Saving optimal models for production: The model at the end of training is often not the best one in terms of the desired metric. One workaround is to identify the best epoch and then retrain for exactly that number of epochs. However, this wastes valuable time, and because training is stochastic, the rerun is not guaranteed to reproduce the same model. Checkpointing saves the model whenever the desired condition is satisfied, which makes training and saving the best model a single, straightforward process.


TensorFlow is an open-source platform for end-to-end machine learning. Although it supports multiple languages, the most popular option for machine learning is Python, and that is what we will use in this article. TensorFlow provides a lower-level checkpointing API in tf.train.Checkpoint, but in this piece we will use the ModelCheckpoint callback from the Keras module.
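For comparison, here is a minimal sketch of the lower-level tf.train.Checkpoint API, paired with tf.train.CheckpointManager so that only the most recent files are kept. A single tf.Variable named step stands in for a model's state, and the temporary directory is an illustrative choice rather than anything used later in this article.

```python
import tempfile

import tensorflow as tf

# A single variable stands in for the model's state (weights, optimizer slots, ...).
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)

# The manager numbers the checkpoint files and deletes all but the newest three.
ckpt_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

for _ in range(5):
    step.assign_add(1)  # pretend some training happened
    manager.save()

step.assign(0)                           # simulate losing the in-memory state
ckpt.restore(manager.latest_checkpoint)  # reload the most recent checkpoint
print(int(step.numpy()), len(manager.checkpoints))  # 5 3
```

The Keras callback used in the rest of this article hides this bookkeeping behind a few arguments.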

ModelCheckpoint and its Arguments

Callbacks are objects whose methods are executed at specific points during training, such as at the end of every epoch. In our case, the callback saves the model at the end of training epochs, optionally only when a condition on a monitored metric is met. The keras.callbacks.ModelCheckpoint class implements checkpointing and is configured through a few arguments, which we discuss briefly below.
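To make the idea of a callback concrete, the plain-Python sketch below shows the calling pattern: a training loop that, after every epoch, hands each callback the epoch index and a logs dictionary. The hook name on_epoch_end(epoch, logs) matches the Keras callback interface, but the Callback base class, LossRecorder and toy_fit are hypothetical stand-ins written for illustration.

```python
class Callback:
    """Stand-in for keras.callbacks.Callback; only the hook used here is defined."""

    def on_epoch_end(self, epoch, logs=None):
        pass


class LossRecorder(Callback):
    """Hypothetical callback that records the validation loss after every epoch."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs=None):
        self.seen.append((epoch, logs["val_loss"]))


def toy_fit(val_losses, callbacks):
    """Hypothetical training loop: after each epoch, every callback's hook runs."""
    for epoch, val_loss in enumerate(val_losses):  # epochs are 0-based, as in Keras hooks
        logs = {"val_loss": val_loss}  # one epoch of real training would produce these logs
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs)


recorder = LossRecorder()
toy_fit([0.50, 0.42, 0.45], callbacks=[recorder])
print(recorder.seen)  # [(0, 0.5), (1, 0.42), (2, 0.45)]
```

ModelCheckpoint plugs into the same hook; instead of recording values, it decides whether to write the model to disk.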

filepath: This is the only required parameter, and it is the location where the checkpoint is stored. Once it is specified, the callback can be used with every other argument left at its default value.
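The filepath can also contain named formatting options, which the callback fills in at save time with the 1-based epoch number and values from that epoch's logs. Plain str.format reproduces the substitution; the file name pattern and the logged values below are illustrative.

```python
# Placeholders in filepath are filled with str.format-style fields: {epoch}
# (1-based) plus any metric name present in that epoch's logs.
filepath = "model.{epoch:02d}-{val_loss:.2f}.keras"

logs = {"loss": 0.3470, "val_loss": 0.3729}  # example end-of-epoch logs
name = filepath.format(epoch=3, **logs)
print(name)  # model.03-0.37.keras
```

Passing a pattern like this as filepath gives every saved model its own file instead of overwriting a single one.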

monitor: This string sets the metric that the callback monitors during training. It defaults to 'val_loss' but can be changed to whatever best fits the scenario. To monitor a training metric instead of a validation metric, drop the val_ prefix (for example, 'loss' or 'accuracy'). For a user-defined metric, pass in the name given to that metric.
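The monitor string is simply looked up in the logs dictionary produced at the end of each epoch. The sketch below uses a hypothetical helper, monitored_value, to show that lookup with the keys Keras reports for a model compiled with metrics=['accuracy'] and trained with validation data. When the key is missing, the real callback warns and skips saving; this sketch raises instead so the mistake is visible.

```python
def monitored_value(logs, monitor="val_loss"):
    """Hypothetical helper: fetch the quantity a checkpoint callback would monitor."""
    if monitor not in logs:
        # The real callback warns and skips saving; this sketch raises to make the problem visible.
        raise KeyError(f"{monitor!r} not found; available keys: {sorted(logs)}")
    return logs[monitor]


# Example end-of-epoch logs for a model compiled with metrics=['accuracy'].
logs = {"loss": 0.3470, "accuracy": 0.8510, "val_loss": 0.3729, "val_accuracy": 0.8547}
print(monitored_value(logs))              # default: validation loss
print(monitored_value(logs, "accuracy"))  # training metric: no val_ prefix
```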

verbose: This integer sets the logging level and is either 0 or 1. With 0 the callback is silent, and with 1 it prints a message whenever it takes an action such as saving. It defaults to 0.

save_best_only: This boolean specifies whether only the best model, according to the monitored metric, should be saved. It defaults to False, so unless it is changed, the model is saved every time a save is due, always to the same file path. If the goal is to keep several models individually, format the filepath to reflect the epoch; otherwise, every save overwrites the file written before it.
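The effect of save_best_only can be traced with a short plain-Python sketch. The function epochs_saved is hypothetical and the validation losses are made up; it only models the decision of when a file is written, not the saving itself.

```python
import operator


def epochs_saved(values, save_best_only=True, mode="min"):
    """1-based epochs at which a ModelCheckpoint-like callback would write a file."""
    is_better = operator.lt if mode == "min" else operator.gt
    best = float("inf") if mode == "min" else float("-inf")
    saved = []
    for epoch, current in enumerate(values, start=1):
        if not save_best_only:
            saved.append(epoch)       # every epoch is saved (to the same path unless it is formatted)
        elif is_better(current, best):
            best = current            # new best value: save and remember it
            saved.append(epoch)
    return saved


val_losses = [0.50, 0.42, 0.45, 0.40, 0.41]  # hypothetical per-epoch val_loss values
print(epochs_saved(val_losses))                        # [1, 2, 4]
print(epochs_saved(val_losses, save_best_only=False))  # [1, 2, 3, 4, 5]
```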

mode: This string can be one of {'auto', 'min', 'max'}. 'min' is used for metrics that should be minimized (e.g. loss) and 'max' for metrics that should be maximized (e.g. accuracy). With 'auto', the default, the direction is inferred from the metric's name: metrics that represent accuracy or F-score are maximized and all others are minimized.
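The 'auto' rule can be sketched as a name-based check. The function below is a simplified illustration modeled on the name heuristic used by tf.keras 2.x; it is not the library's code, and newer Keras releases may infer the direction differently, so setting mode explicitly is the safer choice for custom metrics.

```python
def resolve_mode(monitor, mode="auto"):
    """Simplified sketch of how 'auto' picks a direction from the metric's name."""
    if mode in ("min", "max"):
        return mode                      # an explicit mode always wins
    if "acc" in monitor or monitor.startswith("fmeasure"):
        return "max"                     # accuracy / F-score style metrics
    return "min"                         # everything else, e.g. losses


for name in ("val_loss", "val_accuracy", "loss"):
    print(name, resolve_mode(name))
```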

save_weights_only: This boolean determines whether only the model's weights are saved. If True, model.save_weights is called instead of model.save.

save_freq: This accepts either the string 'epoch', to save the model after every epoch, or an integer, to save after that many batches. The default is 'epoch'. Saving mid-epoch can make the monitored metric less reliable, since it may reflect only the batches seen so far in that epoch.
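With an integer save_freq, saves are tied to batch counts rather than epoch boundaries. The sketch below lists the batch counts that would trigger a save, assuming the batch counter keeps running across epochs; the 391 batches per epoch match the training log later in this article, and the helper name is hypothetical.

```python
def batch_save_points(steps_per_epoch, epochs, save_freq):
    """Global batch counts after which a save would trigger for an integer save_freq.

    Assumption: the batch counter keeps running across epoch boundaries
    instead of restarting at every epoch.
    """
    total_batches = steps_per_epoch * epochs
    return list(range(save_freq, total_batches + 1, save_freq))


# 391 batches per epoch, 2 epochs, save every 100 batches.
points = batch_save_points(steps_per_epoch=391, epochs=2, save_freq=100)
print(points)  # [100, 200, 300, 400, 500, 600, 700]
```

Under this assumption, the save at batch 400 happens only 9 batches into the second epoch, so any metric it monitors would summarize very little data.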

options: This is an optional tf.train.CheckpointOptions object when save_weights_only is True, or a tf.saved_model.SaveOptions object when it is False. These objects provide additional options for saving checkpoints or models, which are useful in distributed settings.
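Putting the arguments together, the sketch below configures a callback that keeps only the model with the highest validation accuracy. It is an illustration rather than the setup used later in this article. The file name best_model.keras is hypothetical; the .keras suffix is there because Keras 3 expects it for full-model saves (and .weights.h5 when save_weights_only=True), while older tf.keras 2.x releases also accept a bare path such as "best_model". The options argument is left out.

```python
from tensorflow import keras

# Keep only the model with the highest validation accuracy seen so far.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath="best_model.keras",  # hypothetical path; Keras 3 expects a .keras suffix here
    monitor="val_accuracy",       # validation metric to watch
    mode="max",                   # higher accuracy is better
    save_best_only=True,          # overwrite the file only when the metric improves
    save_weights_only=False,      # save the full model, not just the weights
    save_freq="epoch",            # evaluate and save at the end of every epoch
    verbose=1,                    # print a message whenever a save happens
)
print(checkpoint.monitor, checkpoint.save_best_only)
```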

Now we know what checkpointing is, why it is necessary and the parameters required to define a checkpoint in TensorFlow. But how is it really used in practice? To answer this, we will be building a model to classify IMDB movie reviews into positive and negative classes.

Installing and importing required libraries

$ pip install tensorflow tensorflow_datasets
import numpy as np
import tensorflow as tf
from tensorflow import keras

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In the code snippet above, we first install tensorflow and tensorflow_datasets with pip via the command line in case they are not already installed in your development environment. Next, we import numpy and tensorflow. We also import keras from tensorflow as we will be making several calls from that specific module. Finally, we import tensorflow_datasets and disable its progress bar.

Creating a visualization function

import matplotlib.pyplot as plt

def plot_graphs(history, metric):
  plt.plot(history.history[metric])
  plt.plot(history.history['val_'+metric], '')
  plt.xlabel("Epochs")
  plt.ylabel(metric)
  plt.legend([metric, 'val_'+metric])

Next, we import matplotlib.pyplot to create a function that will aid visualization later on. The function takes in a history object and a desired metric and plots the train and validation data for that metric against the number of epochs. Labels are also added for the x and y axes.


Load IMDB reviews dataset

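dataset, info = tfds.load('imdb_reviews', with_info=True, as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
print(info.description, info.citation, sep='\n')
print(train_dataset.element_spec)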


Here, we pass in the name of the dataset as a string and two parameters that determine the nature of the information returned. with_info is a boolean that determines whether metadata about the dataset is also returned, and as_supervised is a boolean that, when True, makes each example a (text, label) tuple suited to supervised learning instead of a dictionary.

Next, we separate the dataset into its train and test portions by passing in 'train' and 'test' as keys to the dataset variable. To conclude, we display the description and citation information from the information collected. We also view the element specifications of the train_dataset variable to see the datatypes of its constituents. We can see that it contains two Tensors with tf.string and tf.int64 data types respectively.

Configuring the data pipeline

info.splits
{'test': <tfds.core.SplitInfo num_examples=25000>,
 'train': <tfds.core.SplitInfo num_examples=25000>,
 'unsupervised': <tfds.core.SplitInfo num_examples=50000>}

First, we use the splits attribute to see the size of the train and test data. This is given as 25000 each.


BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

for review, label in train_dataset.take(1):
  print('reviews: ', review.numpy()[:3])
  print('\nlabels: ', label.numpy()[:3])
reviews:  [b'This was a big disappointment for me. I think this is the worst Mastroianni-movie ever made. Cosmatos tries too hard to make this movie a masterpiece and that makes this movie a typical "art"-movie. I give 4/10 for this movie.'
 b"This picture for me scores very highly as it is a hugely enjoyable ... These type of films often deserve a cult following:<br /><br />8/10."
 b'Almost too well done... "John Carpenter\'s Vampires" was entertaining, ... "Vampires: Los Muertos" is almost too well done. (I give it 7 of 10)']

labels:  [0 1 1]

In the code snippet above, we set a batch size and a buffer size. The batch size determines the number of samples that are processed at a time and is passed into the batch method. The buffer size, on the other hand, is passed to the shuffle method, where it sets how many samples are held in the shuffle buffer. With a buffer size of 10000, the method fills a buffer with the first 10000 of the 25000 training samples and draws each output sample at random from that buffer.

Every time a sample is drawn, its slot is refilled with the next sample from the remaining 15000, and once the input runs out, the rest of the buffer is emitted in random order. The shuffled stream is then grouped into batches of the set batch size. It is worth noting that the shuffle method is used only on the train dataset, since the order does not matter for the test dataset.
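The buffered shuffle described above can be sketched in plain Python. This is an approximation of what Dataset.shuffle(buffer_size) does, written for illustration (the function name and seeding are hypothetical). The last lines check the batch arithmetic, assuming a batch size of 64.

```python
import math
import random


def buffer_shuffle(items, buffer_size, seed=0):
    """Sketch of a buffered shuffle in the style of Dataset.shuffle(buffer_size)."""
    rng = random.Random(seed)
    it = iter(items)
    buffer = []
    for item in it:  # fill the buffer with the first buffer_size elements
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    out = []
    for item in it:
        # Emit a random element from the buffer and put the next input in its slot.
        idx = rng.randrange(len(buffer))
        out.append(buffer[idx])
        buffer[idx] = item
    rng.shuffle(buffer)  # input exhausted: drain the rest of the buffer in random order
    out.extend(buffer)
    return out


shuffled = buffer_shuffle(range(100), buffer_size=10)
print(shuffled[:10])

# 25,000 reviews in batches of 64 give 391 batches (the last one partial).
steps_per_epoch = math.ceil(25000 / 64)
print(steps_per_epoch)  # 391
```

A small buffer only mixes nearby elements (the first output always comes from the first 10 inputs here), which is why a buffer that is large relative to the dataset gives a better shuffle.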

Finally, we display a portion of the dataset by taking a batch from the train_dataset and printing the first 3 reviews and labels.

Creating the text encoder

So far, we still have the input data as raw text but it will need to be encoded into numbers for a model to train on it.

VOCAB_SIZE = 1000
encoder = keras.layers.TextVectorization(max_tokens=VOCAB_SIZE)
encoder.adapt(train_dataset.map(lambda text, label: text))

First, we set the vocabulary size to 1000. This means that only the most frequent words get their own index; the limit of 1000 also counts the padding token and the '[UNK]' token, and every other word is mapped to '[UNK]' as an unknown word.

Then, we pass this integer to the max_tokens argument of the keras.layers.TextVectorization layer and adapt the encoder to our training data. Since adapt needs only the text, we use the map method with a lambda function to extract just the reviews from train_dataset.
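A rough plain-Python picture of what adapt builds and how the encoder then uses it: standardize the text, count words, keep the most frequent ones after the reserved padding ('') and '[UNK]' entries, and map everything else to index 1. The helper names, the alphabetical tie-breaking and the fixed-length padding are assumptions of this sketch, not TextVectorization internals (the real layer pads to the longest sequence in a batch by default).

```python
import re
from collections import Counter


def standardize(text):
    """Roughly mimics the default lowercase-and-strip-punctuation standardization."""
    return re.sub(r"[^\w\s]", "", text.lower())


def build_vocab(texts, max_tokens):
    counts = Counter(w for t in texts for w in standardize(t).split())
    # Index 0 is padding ('') and index 1 is '[UNK]'; both count toward max_tokens.
    ranked = sorted(counts, key=lambda w: (-counts[w], w))  # frequency, then alphabetical
    return ["", "[UNK]"] + ranked[: max_tokens - 2]


def encode(text, vocab, length):
    index = {w: i for i, w in enumerate(vocab)}
    ids = [index.get(w, 1) for w in standardize(text).split()]  # unknown words -> 1
    return (ids + [0] * length)[:length]  # pad with 0, or truncate, to a fixed length


vocab = build_vocab(["The movie was great.", "The movie was bad!", "the plot"], max_tokens=5)
print(vocab)                                          # ['', '[UNK]', 'the', 'movie', 'was']
print(encode("The movie was terrible", vocab, length=6))  # [2, 3, 4, 1, 0, 0]
```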

vocab = np.array(encoder.get_vocabulary())
vocab[:20]
array(['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i',
       'this', 'that', 'br', 'was', 'as', 'for', 'with', 'movie', 'but'],
encoded_example = encoder(review)[:3].numpy()
encoded_example
array([[ 11,  14,   4, ...,   0,   0,   0],
       [ 11, 433,  16, ...,   0,   0,   0],
       [210, 100,  74, ...,   1,   5, 302]])

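Here, we output the first 20 elements of the encoder's vocabulary. Notice the '[UNK]' string, which represents all unknown words, and the empty string at index 0, which is used for padding. We also output the encodings of three reviews; shorter reviews are padded with zeros to the length of the longest review in the batch. With the default settings of TextVectorization, this process is not fully reversible, because the text is lowercased and stripped of punctuation, and rare words all map to '[UNK]'. However, since we are building a classifier, that is not a problem.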
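To see the loss of information concretely, the sketch below decodes an encoded sequence with a tiny hypothetical vocabulary in plain Python. With the real encoder, indexing the vocab array with a row of encoded_example does the same job.

```python
# Tiny hypothetical vocabulary laid out like the encoder's: padding first, then [UNK].
vocab = ["", "[UNK]", "the", "movie", "was"]

# "The movie was TERRIBLE!" after standardization and encoding, padded to length 6.
encoded = [2, 3, 4, 1, 0, 0]

decoded = " ".join(vocab[i] for i in encoded if i != 0)  # drop padding
print(decoded)  # the movie was [UNK]
```

Capitalization, punctuation and the out-of-vocabulary word are all gone, but the tokens that carry most of the signal for classification survive.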

Building the model

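model = keras.Sequential([encoder,
    keras.layers.Embedding(input_dim=len(encoder.get_vocabulary()), output_dim=64),
    keras.layers.Bidirectional(keras.layers.LSTM(64)),
    keras.layers.Dense(64, activation='relu'), keras.layers.Dense(1)])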

The model is defined as a keras.Sequential object. The first layer is the encoder we just built, so input encoding is the first thing that happens. Next is an embedding layer, which converts the sequences of word indices into sequences of trainable vectors. After training, words with similar meanings tend to have similar embeddings.
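An embedding layer behaves like a trainable lookup table with one row per vocabulary index. The NumPy sketch below performs the same lookup with a random table (in Keras the rows are learned), assuming a 1000-word vocabulary and 64-dimensional embeddings; the example indices are arbitrary.

```python
import numpy as np

vocab_size, embed_dim = 1000, 64
rng = np.random.default_rng(seed=0)

# One row per vocabulary index; in Keras these rows are trainable weights.
table = rng.normal(size=(vocab_size, embed_dim)).astype("float32")

# A batch of 2 encoded sequences of length 4 (arbitrary example indices).
batch_of_ids = np.array([[11, 14, 4, 0], [210, 100, 74, 1]])

vectors = table[batch_of_ids]  # integer-array indexing performs the lookup
print(vectors.shape)  # (2, 4, 64)
print(table.size)     # 64000 parameters in the table
```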

Next, we add a Long Short Term Memory (LSTM) layer wrapped in a keras.layers.Bidirectional layer. The LSTM processes sequences by iterating through the elements and passing the results from one timestep to the next. Then the bidirectional wrapper propagates the input forward and backwards through the LSTM layer and concatenates the output.
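The effect of the bidirectional wrapper shows up in the output shape: with its default merge_mode='concat', the forward and backward outputs of LSTM(64) are concatenated into 128 values per example. A quick check on random input of an assumed shape (2 sequences, 10 timesteps, 64 features):

```python
import tensorflow as tf
from tensorflow import keras

# Random stand-in for embedded sequences: (batch, timesteps, features).
x = tf.random.normal((2, 10, 64))

forward_only = keras.layers.LSTM(64)
bidirectional = keras.layers.Bidirectional(keras.layers.LSTM(64))  # merge_mode='concat'

print(forward_only(x).shape)   # (2, 64): final state of a single left-to-right pass
print(bidirectional(x).shape)  # (2, 128): forward and backward outputs concatenated
```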

Once the LSTM has transformed each sequence into a single vector, the two keras.layers.Dense layers do the final processing to convert that vector into a single logit as the classification output. If this value is greater than 0, the review is classified as positive; otherwise, it is classified as negative.
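The "greater than 0" rule works because the model outputs a logit, and the sigmoid function maps a logit of 0 to a probability of exactly 0.5. A small plain-Python check of that relationship:

```python
import math


def sigmoid(z):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def label_from_logit(logit):
    # logit > 0 is equivalent to sigmoid(logit) > 0.5
    return "Positive" if logit > 0 else "Negative"


for logit in (-2.0, 0.0, 1.5):
    print(logit, round(sigmoid(logit), 3), label_from_logit(logit))
```

Thresholding the raw logit therefore gives the same decision as thresholding the probability at 0.5, without computing the sigmoid.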

All hidden layers were set to have just 64 neurons. This can of course be changed if it improves performance.

To check that the pipeline works end to end, we pass a sample text to the untrained model via the model.predict method and print the result, as shown below.

sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print("Positive" if predictions[0]>0 else "Negative")

Because the model is untrained, its output is essentially random; in this run, it wrongly classified the positive review as negative.

Compiling the Model

We have come a long way but it was indeed important to see checkpointing used in an actual problem so that its utility can be appreciated. Next, we compile the model and define the checkpointing and early stopping callbacks.


checkpoint = keras.callbacks.ModelCheckpoint("best_model",
stop_early = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

In the above code snippet, we compile the model with a BinaryCrossentropy loss (with from_logits=True, since the final layer outputs a raw logit), an Adam optimizer with a learning rate of 1e-4, and 'accuracy' as a metric.

Next, the checkpoint callback is created by instantiating keras.callbacks.ModelCheckpoint, passing in the string 'best_model' as the file path to save to and setting save_best_only to True. All other parameters keep the default values stated earlier.

Next, we define the early stopping callback and pass in parameters that make it monitor the validation loss with a patience of 5 epochs. This means that training stops once the validation loss has gone 5 consecutive epochs without improving (decreasing).

Training the model

Here, we call the fit method, passing in the training data, the validation data, the number of epochs (10), the number of validation steps, and the callbacks. The validation_steps parameter limits validation to 30 test batches per epoch. It can be left out to validate on the entire test set, but that is a costly process. The callbacks parameter does the trick by taking both the early stopping and checkpointing callbacks as elements of a list. Running this cell trains the model and saves it every time a new best validation loss is obtained.

The results of this training are stored in the history variable and visualized as shown below focusing on accuracy and loss.

history = model.fit(train_dataset, epochs=10, validation_data=test_dataset,
                    validation_steps=30,
                    callbacks=[stop_early, checkpoint])
Epoch 1/10
391/391 [==============================] - 94s 229ms/step - loss: 0.6162 - accuracy: 0.5985 - val_loss: 0.8097 - val_accuracy: 0.5661
INFO:tensorflow:Assets written to: best_model/assets
INFO:tensorflow:Assets written to: best_model/assets
Epoch 2/10
391/391 [==============================] - 88s 223ms/step - loss: 0.4062 - accuracy: 0.8184 - val_loss: 0.3691 - val_accuracy: 0.8339
INFO:tensorflow:Assets written to: best_model/assets
INFO:tensorflow:Assets written to: best_model/assets
Epoch 3/10
391/391 [==============================] - 89s 225ms/step - loss: 0.3470 - accuracy: 0.8510 - val_loss: 0.3729 - val_accuracy: 0.8547
Epoch 4/10
391/391 [==============================] - 88s 223ms/step - loss: 0.3334 - accuracy: 0.8568 - val_loss: 0.3491 - val_accuracy: 0.8380
INFO:tensorflow:Assets written to: best_model/assets
INFO:tensorflow:Assets written to: best_model/assets
Epoch 5/10
391/391 [==============================] - 89s 226ms/step - loss: 0.3245 - accuracy: 0.8619 - val_loss: 0.3371 - val_accuracy: 0.8479
Epoch 6/10
391/391 [==============================] - 90s 226ms/step - loss: 0.3180 - accuracy: 0.8645 - val_loss: 0.3372 - val_accuracy: 0.8526
Epoch 7/10
391/391 [==============================] - 90s 228ms/step - loss: 0.3174 - accuracy: 0.8658 - val_loss: 0.3275 - val_accuracy: 0.8604
INFO:tensorflow:Assets written to: best_model/assets
INFO:tensorflow:Assets written to: best_model/assets
Epoch 8/10
391/391 [==============================] - 90s 227ms/step - loss: 0.3120 - accuracy: 0.8664 - val_loss: 0.3359 - val_accuracy: 0.8609
Epoch 9/10
391/391 [==============================] - 89s 225ms/step - loss: 0.3111 - accuracy: 0.8681 - val_loss: 0.3378 - val_accuracy: 0.8552
Epoch 10/10
391/391 [==============================] - 90s 226ms/step - loss: 0.3077 - accuracy: 0.8698 - val_loss: 0.3285 - val_accuracy: 0.8562
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
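Besides plotting, the History object can be queried directly: history.history is a dictionary that maps each metric name to its per-epoch values. The sketch below rebuilds the val_loss entry by hand from the training log above (so it runs without TensorFlow) and finds the best epoch, which is the last one that save_best_only wrote to disk.

```python
# Stand-in for history.history, rebuilt by hand from the val_loss column of the log above.
history_dict = {
    "val_loss": [0.8097, 0.3691, 0.3729, 0.3491, 0.3371,
                 0.3372, 0.3275, 0.3359, 0.3378, 0.3285],
}

val_loss = history_dict["val_loss"]  # with a real run: history.history["val_loss"]
best_index = min(range(len(val_loss)), key=val_loss.__getitem__)
best_epoch = best_index + 1          # Keras logs epochs starting from 1
print(best_epoch, val_loss[best_index])  # 7 0.3275
```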

Load Best Model

loaded_model = keras.models.load_model('best_model')

sample_text = ('The movie was nice. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = loaded_model.predict(np.array([sample_text]))
print("Positive" if predictions[0][0]>0 else "Negative")

To load the saved best model, we simply pass the file path to keras.models.load_model. Next, we make a prediction by passing a sample text to the loaded model's predict method. We get a positive result, which matches our own reading of the review.

Continue Training with checkpoints

To continue training a loaded model with checkpoints, we simply call fit on it again with the callbacks still passed in. This, however, overwrites the currently saved best model, so make sure to change the checkpoint file path if this is undesired.

loaded_model = keras.models.load_model('best_model')
new_history = loaded_model.fit(train_dataset, epochs=20,
                               validation_data=test_dataset,
                               validation_steps=30,
                               callbacks=[stop_early, checkpoint])
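If you want the resumed run to save only models that beat the first run, recent TensorFlow/Keras versions let ModelCheckpoint start from a known best value through initial_value_threshold. A minimal sketch, assuming previous_best holds the lowest val_loss from the first run (0.3275 in the log above) and using a new, hypothetical file name so the original best model is left untouched:

```python
from tensorflow import keras

previous_best = 0.3275  # lowest val_loss from the first run (see the log above)

resume_checkpoint = keras.callbacks.ModelCheckpoint(
    "best_model_resumed.keras",            # hypothetical new path; the first run's model is kept
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    initial_value_threshold=previous_best,  # save only if val_loss drops below this
)
print(resume_checkpoint.best)  # 0.3275
```

Passing resume_checkpoint in place of checkpoint in the callbacks list of the second fit call would apply this threshold.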


Conclusion

In this tutorial, you have learned what checkpointing is, why it is important, the basics of natural language processing in the form of sentiment analysis, and how to implement checkpointing in a practical scenario. I hope you found this useful.

