Deep Learning Model Explainability with SHAP

In this article, we examine the game-theory-based approach to explaining the outputs of machine learning models: SHapley Additive exPlanations, or SHAP. We then demonstrate the technique using sample images in a Gradient Notebook.

By Oreolorun Olu-Ipinlaye

Deep learning models are often said to be black boxes: their outputs can be difficult to explain, or sometimes simply unexplainable. However, there are Python libraries which help to provide some sort of explanation for the outputs of deep learning models. In this article, we will be taking a look at one of those libraries: SHAP.

#  article dependencies
!pip install shap

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as Datasets
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm.notebook import tqdm
import seaborn as sns
from torchvision.utils import make_grid
import shap

#  selecting a device
if torch.cuda.is_available():
  device = torch.device('cuda:0')
  print('Running on the GPU')
else:
  device = torch.device('cpu')
  print('Running on the CPU')

Model Explainability

Model explainability refers to the process whereby the outputs produced by a machine learning model are explained in terms of how and which features influenced them. For instance, consider a random forest model trained to predict house prices, and assume that the dataset it was trained on has only 3 features: number of bedrooms, number of bathrooms, and size of the living room. If the model predicts a house to be worth about $300,000, model explainability lets us derive insight into how much each feature contributed, positively or negatively, to that predicted price.
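To make this concrete, below is a minimal, hypothetical sketch of tabular explainability with SHAP's TreeExplainer. The model, feature values, and prices are illustrative only (and scikit-learn is assumed to be available), but the workflow mirrors the house price example: fit a model, build an explainer, and read off per-feature contributions.

#  a hedged, illustrative sketch -- not a real housing dataset
import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor

#  toy features: [bedrooms, bathrooms, living room size (sq ft)]
X = np.array([[2, 1, 300], [3, 2, 450], [4, 3, 600], [3, 1, 350], [5, 3, 700]])
y = np.array([200_000, 280_000, 400_000, 250_000, 500_000])  #  toy prices

model = RandomForestRegressor(random_state=0).fit(X, y)

#  TreeExplainer assigns each feature a contribution (a SHAP value)
#  relative to the model's average prediction over the training data
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(np.array([[3, 2, 500]]))
print(shap_values)  #  positive values push the price up, negative pull it down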

Model Explainability in the Context of Computer Vision

In deep learning, and computer vision classification tasks in particular, the features are essentially pixels, so model explainability helps to identify the pixels which contribute negatively or positively to the predicted class.

In this article, the SHAP library will be used for deep learning model explainability. SHAP, short for SHapley Additive exPlanations, is a game-theory-based approach to explaining the outputs of machine learning models; more information can be found in its official documentation.

Implementing Deep Learning Model Explainability

In this section, we will train a convolutional neural network for a classification task, then derive insight into why the model classifies an instance of data into a specific class using the SHAP library.

Dataset

The dataset used for training in this article is the CIFAR10 dataset: a collection of 32 x 32 pixel images belonging to 10 distinct classes, ranging from airplanes to trucks. It can be loaded in PyTorch using the code cell below.

#  loading training data
training_set = Datasets.CIFAR10(root='./', download=True,
                                transform=transforms.ToTensor())

#  loading validation data
validation_set = Datasets.CIFAR10(root='./', download=True, train=False,
                                  transform=transforms.ToTensor())
The labels in CIFAR10 map to classes as follows:

Label    Description
0        Airplane
1        Automobile
2        Bird
3        Cat
4        Deer
5        Dog
6        Frog
7        Horse
8        Ship
9        Truck

Model Architecture

The model architecture is implemented in the following code cell. This is a custom architecture designed purposefully for the sake of this article. It takes in 32 x 32 pixel images and comprises 7 convolutional layers.

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
    self.batchnorm1 = nn.BatchNorm2d(8)
    self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
    self.batchnorm2 = nn.BatchNorm2d(8)
    self.pool2 = nn.MaxPool2d(2)
    self.conv3 = nn.Conv2d(8, 32, 3, padding=1)
    self.batchnorm3 = nn.BatchNorm2d(32)
    self.conv4 = nn.Conv2d(32, 32, 3, padding=1)
    self.batchnorm4 = nn.BatchNorm2d(32)
    self.pool4 = nn.MaxPool2d(2)
    self.conv5 = nn.Conv2d(32, 128, 3, padding=1)
    self.batchnorm5 = nn.BatchNorm2d(128)
    self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
    self.batchnorm6 = nn.BatchNorm2d(128)
    self.pool6 = nn.MaxPool2d(2)
    self.conv7 = nn.Conv2d(128, 10, 1)
    self.pool7 = nn.AvgPool2d(3)

  def forward(self, x):
    #-------------
    # INPUT
    #-------------
    x = x.view(-1, 3, 32, 32)
    
    #-------------
    # LAYER 1
    #-------------
    output_1 = self.conv1(x)
    output_1 = F.relu(output_1)
    output_1 = self.batchnorm1(output_1)

    #-------------
    # LAYER 2
    #-------------
    output_2 = self.conv2(output_1)
    output_2 = F.relu(output_2)
    output_2 = self.pool2(output_2)
    output_2 = self.batchnorm2(output_2)

    #-------------
    # LAYER 3
    #-------------
    output_3 = self.conv3(output_2)
    output_3 = F.relu(output_3)
    output_3 = self.batchnorm3(output_3)

    #-------------
    # LAYER 4
    #-------------
    output_4 = self.conv4(output_3)
    output_4 = F.relu(output_4)
    output_4 = self.pool4(output_4)
    output_4 = self.batchnorm4(output_4)

    #-------------
    # LAYER 5
    #-------------
    output_5 = self.conv5(output_4)
    output_5 = F.relu(output_5)
    output_5 = self.batchnorm5(output_5)

    #-------------
    # LAYER 6
    #-------------
    output_6 = self.conv6(output_5)
    output_6 = F.relu(output_6)
    output_6 = self.pool6(output_6)
    output_6 = self.batchnorm6(output_6)

    #--------------
    # OUTPUT LAYER
    #--------------
    output_7 = self.conv7(output_6)
    output_7 = self.pool7(output_7)
    output_7 = output_7.view(-1, 10)

    return F.softmax(output_7, dim=1)

Convolutional Neural Network Class

In order to neatly put together our model, we will write a class which encompasses training, validation, and model utilization in one object, as seen below.

class ConvolutionalNeuralNet():
  def __init__(self, network):
    self.network = network.to(device)
    self.optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-3)

  def train(self, loss_function, epochs, batch_size, 
            training_set, validation_set):
    
    #  creating log
    log_dict = {
        'training_loss_per_batch': [],
        'validation_loss_per_batch': [],
        'training_accuracy_per_epoch': [],
        'training_recall_per_epoch': [],
        'training_precision_per_epoch': [],
        'validation_accuracy_per_epoch': [],
        'validation_recall_per_epoch': [],
        'validation_precision_per_epoch': []
    } 

    #  defining weight initialization function
    def init_weights(module):
      if isinstance(module, nn.Conv2d):
        torch.nn.init.xavier_uniform_(module.weight)
        module.bias.data.fill_(0.01)
      elif isinstance(module, nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        module.bias.data.fill_(0.01)

    #  defining accuracy function
    def accuracy(network, dataloader):
      network.eval()
      
      all_predictions = []
      all_labels = []

      #  computing accuracy
      total_correct = 0
      total_instances = 0
      for images, labels in tqdm(dataloader):
        images, labels = images.to(device), labels.to(device)
        all_labels.extend(labels)
        predictions = torch.argmax(network(images), dim=1)
        all_predictions.extend(predictions)
        correct_predictions = sum(predictions==labels).item()
        total_correct+=correct_predictions
        total_instances+=len(images)
      accuracy = round(total_correct/total_instances, 3)

      #  computing recall and precision (class 1 is treated as the positive class)
      true_positives = 0
      false_negatives = 0
      false_positives = 0
      for idx in range(len(all_predictions)):
        if all_predictions[idx].item()==1 and all_labels[idx].item()==1:
          true_positives+=1
        elif all_predictions[idx].item()!=1 and all_labels[idx].item()==1:
          false_negatives+=1
        elif all_predictions[idx].item()==1 and all_labels[idx].item()!=1:
          false_positives+=1
      try:
        recall = round(true_positives/(true_positives + false_negatives), 3)
      except ZeroDivisionError:
        recall = 0.0
      try:
        precision = round(true_positives/(true_positives + false_positives), 3)
      except ZeroDivisionError:
        precision = 0.0
      return accuracy, recall, precision

    #  initializing network weights
    self.network.apply(init_weights)

    #  creating dataloaders
    train_loader = DataLoader(training_set, batch_size)
    val_loader = DataLoader(validation_set, batch_size)

    for epoch in range(epochs):
      print(f'Epoch {epoch+1}/{epochs}')
      train_losses = []

      #  setting convnet to training mode
      #  (reset each epoch because the accuracy helper and validation loop
      #  switch the network to evaluation mode)
      self.network.train()

      #  training
      print('training...')
      for images, labels in tqdm(train_loader):
        #  sending data to device
        images, labels = images.to(device), labels.to(device)
        #  resetting gradients
        self.optimizer.zero_grad()
        #  making predictions
        predictions = self.network(images)
        #  computing loss
        loss = loss_function(predictions, labels)
        log_dict['training_loss_per_batch'].append(loss.item())
        train_losses.append(loss.item())
        #  computing gradients
        loss.backward()
        #  updating weights
        self.optimizer.step()
      with torch.no_grad():
        print('deriving training accuracy...')
        #  computing training accuracy
        train_accuracy, train_recall, train_precision = accuracy(self.network, train_loader)
        log_dict['training_accuracy_per_epoch'].append(train_accuracy)
        log_dict['training_recall_per_epoch'].append(train_recall)
        log_dict['training_precision_per_epoch'].append(train_precision)

      #  validation
      print('validating...')
      val_losses = []

      #  setting convnet to evaluation mode
      self.network.eval()

      with torch.no_grad():
        for images, labels in tqdm(val_loader):
          #  sending data to device
          images, labels = images.to(device), labels.to(device)
          #  making predictions
          predictions = self.network(images)
          #  computing loss
          val_loss = loss_function(predictions, labels)
          log_dict['validation_loss_per_batch'].append(val_loss.item())
          val_losses.append(val_loss.item())
        #  computing accuracy
        print('deriving validation accuracy...')
        val_accuracy, val_recall, val_precision = accuracy(self.network, val_loader)
        log_dict['validation_accuracy_per_epoch'].append(val_accuracy)
        log_dict['validation_recall_per_epoch'].append(val_recall)
        log_dict['validation_precision_per_epoch'].append(val_precision)

      train_losses = np.array(train_losses).mean()
      val_losses = np.array(val_losses).mean()

      print(f'training_loss: {round(train_losses, 4)}  training_accuracy: '+
      f'{train_accuracy}  training_recall: {train_recall}  training_precision: {train_precision} *~* validation_loss: {round(val_losses, 4)} '+  
      f'validation_accuracy: {val_accuracy}  validation_recall: {val_recall}  validation_precision: {val_precision}\n')
      
    return log_dict

  def predict(self, x):
    return self.network(x)

Model Training

With everything set up, it's now time to train the model. Using the parameters defined below, the model is trained for 15 epochs.

model = ConvolutionalNeuralNet(ConvNet())

log_dict = model.train(nn.CrossEntropyLoss(), epochs=15, batch_size=64, 
                       training_set=training_set, validation_set=validation_set)

From the results obtained, both training and validation accuracy increased over the course of model training. Validation accuracy attained a value just under 75%: not the best-performing model, but one sufficient for this article's objectives. Furthermore, both training and validation losses are down-trending, an indication that better performance could be obtained with more epochs of training.
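The plots below can be reproduced from the log returned by model.train(); a minimal sketch using the seaborn and matplotlib imports from earlier:

#  a sketch of plotting the training curves from the returned log
sns.lineplot(x=list(range(len(log_dict['training_loss_per_batch']))),
             y=log_dict['training_loss_per_batch'])
plt.xlabel('batch')
plt.ylabel('training loss')
plt.show()

sns.lineplot(x=list(range(1, len(log_dict['validation_accuracy_per_epoch']) + 1)),
             y=log_dict['validation_accuracy_per_epoch'])
plt.xlabel('epoch')
plt.ylabel('validation accuracy')
plt.show()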

Accuracy and loss plots.

Model Explainability

In this section we will be attempting to explain/derive insights into the classifications made by the model trained in the previous section. As mentioned previously, we will be using the SHAP library for this purpose.

The library does this by using the model to classify a number of data instances in order to understand its behavior and the nature of its outputs; this 'understanding' is called the explainer. Next, using the explainer object, values are assigned to each feature (pixels, in this case) which influences the classification made by the model; these values are termed SHAP values. SHAP values are the actual metrics that convey explainability: based on their magnitudes, one can develop an idea of how each pertinent feature has contributed to the classification made by the model. Finally, a plot called a SHAP plot is produced to make interpretation of these values easier.
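One useful property to keep in mind: SHAP values are additive, meaning that for any class, the explainer's expected value plus the sum of all SHAP values for that class should roughly recover the model's output for it. Below is a hedged sanity check, assuming an explainer and an image tensor like the ones constructed in the sections that follow:

#  hedged additivity check -- assumes an `explainer` and an `image` tensor
#  like the ones created later in this article
shap_values = explainer.shap_values(image.view(-1, 3, 32, 32))

with torch.no_grad():
  output = model.network(image.view(-1, 3, 32, 32)).cpu().numpy()[0]

for k in range(10):
  approx = shap_values[k].sum() + explainer.expected_value[k]
  print(f'class {k}: model output {output[k]:.4f} vs '
        f'expected value + shap sum {approx:.4f}')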

Creating a Mask

As mentioned previously, in order to generate SHAP values, an explainer has to be generated first. This explainer makes classifications on some background data instances, collectively called a mask. For this article, the first 200 instances in the validation set are selected as the mask. These are thereafter converted into a PyTorch dataset by instantiating them as members of the CustomMask class below.

#  defining dataset class
class CustomMask(Dataset):
  def __init__(self, data, transforms=None):
    self.data = data
    self.transforms = transforms

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    image = self.data[idx]

    if self.transforms is not None:
      image = self.transforms(image)
    return image
    
#  creating explainer mask
mask = validation_set.data[:200]

#  turning mask to pytorch dataset
mask = CustomMask(mask, transforms=transforms.ToTensor())
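As a quick sanity check (and to put the make_grid import from the dependencies to use), we can visualize a handful of the mask's images; a small sketch:

#  visualizing some of the mask's images
mask_images = torch.stack([mask[i] for i in range(32)])
grid = make_grid(mask_images, nrow=8)

plt.figure(dpi=120)
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()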

Explainability Function

All the steps outlined above can then be put together to produce a function which implements model explainability by producing SHAP plots for any instance of data classified by the model.

The function below does exactly that. Firstly, it takes in an image in array form, a mask, and a deep learning model as parameters. Next, the image array is converted to a tensor and a classification is made, with the predicted class index mapped to its label via a dictionary of CIFAR10 classes.

Thereafter, an explainer is derived from the mask and model supplied, and SHAP values are produced for the image of choice using this explainer. A SHAP plot is then returned for easy interpretation.

def plot_shap(image_array, mask, model):
  """
  This function performs model explainability
  by producing shap plots for a data instance
  """
  #  converting image to tensor
  image = transforms.ToTensor()(image_array)
  image = image.to(device)

  #-----------------
  #  CLASSIFICATION
  #-----------------
  #  creating a mapping of classes to labels
  label_dict = {0:'airplane', 1:'automobile', 2:'bird', 3:'cat', 4:'deer',
                5:'dog', 6:'frog', 7:'horse', 8:'ship', 9:'truck'}

  #  utilizing the model for classification
  with torch.no_grad():
    prediction = torch.argmax(model(image), dim=1).item()

  #  displaying model classification
  print(f'prediction: {label_dict[prediction]}')

  #----------------
  #  EXPLAINABILITY
  #----------------
  #  creating dataloader for mask
  mask_loader = DataLoader(mask, batch_size=200)

  #  creating explainer for model behaviour
  for images in mask_loader:
    images = images.to(device)
    explainer = shap.DeepExplainer(model, images)
    break

  #  deriving shap values for image of interest based on model behaviour
  shap_values = explainer.shap_values(image.view(-1, 3, 32, 32))

  #  preparing for visualization by changing channel arrangement
  shap_numpy = [np.swapaxes(np.swapaxes(x, 1, -1), 1, 2) for x in shap_values]
  image_numpy = np.swapaxes(np.swapaxes(image.view(-1, 3, 32, 32).cpu().numpy(), 1, -1), 1, 2)

  #  producing shap plots
  shap.image_plot(shap_numpy, image_numpy, show=False,
                  labels=['airplane', 'automobile', 'bird', 'cat', 'deer',
                          'dog', 'frog', 'horse', 'ship', 'truck'])

Understanding SHAP Plots

Utilizing the function written above we can then begin to develop an understanding of why the model classifies an instance of data as it has. For a quick and easy demonstration, we can simply use images in the validation set as seen in the code cell below.

plot_shap(validation_set.data[-150], mask, model.network)

From the output returned, the model correctly predicts this image instance as a Horse. The ensuing SHAP plot consists of the original image followed by 10 dimmed grayscale versions of itself. Each grayscale image corresponds to an individual class in the dataset and is labeled as such. Beneath the plot is a scale which reads from negative to positive, color-coded from deep blue to bright red. This scale shows the intensity of the SHAP value assigned to each pertinent pixel.

Pixels colored deep blue are those which push the model away from predicting that the image belongs to that particular class, while pixels colored bright red are those which strongly indicate that the image probably belongs to the class in question; white coloration, on the other hand, shows that no importance was placed on those pixels by the model. Shades in between vary proportionally.

Taking another look at the plot above, it can be seen that the model has narrowed its gaze down to two classes for this particular instance of data: Deer and Horse. In both classes, there are similar patterns of red pixels at the top of the image, which implies that objects in that part of the image are common to images of deer and horses (i.e., most deer and horses in the training set are pictured against a woodland background, as seen in this data instance). However, looking at the pixels along the position of the object of interest shows that the Horse class possesses more red pixels than the Deer class. This means that the model has perceived the shape of the object to be closer to that of a horse.
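One rough way to quantify this visual comparison is to sum the SHAP values per class: the class whose pixels collectively push hardest in the positive direction should match the model's leaning. The sketch below assumes plot_shap has been modified to also return its shap_values for the image:

#  assumes `shap_values` has been returned from plot_shap for this image
label_dict = {0:'airplane', 1:'automobile', 2:'bird', 3:'cat', 4:'deer',
              5:'dog', 6:'frog', 7:'horse', 8:'ship', 9:'truck'}
totals = {label_dict[k]: float(shap_values[k].sum()) for k in range(10)}
for label, total in sorted(totals.items(), key=lambda kv: kv[1], reverse=True):
  print(f'{label:>10}: {total:+.5f}')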

Example 2

Consider the image instance above, again derived from the validation set. This image is correctly classified as a Deer, but looking at the SHAP plots, one can see that the model had a more difficult time deciding which class the image belongs to compared to the previous image. All of the classes are lit up with red and blue pixels in this case, with the automobile, bird, and truck classes less lit than the others.

The cat, deer, dog, frog, and horse classes have the most activity on their grayscales, particularly on their backgrounds, as it seems a significant number of the training images in those classes are pictured against grass backgrounds. However, the model has classified the image as a Deer since there are fewer blue pixels overall compared to the other classes.

Example 3

Unlike the other two images, this data instance, which is evidently a dog, was misclassified as an airplane. On the surface this might seem like a rather bizarre classification, but looking at the SHAP plots sheds more light on why the model behaved this way.

From the plot, both the airplane and the dog classes were assumed to be most likely. However, unique differences can be seen in the nature of the SHAP values along the edges of the grayscales: the ear and neck region of the dog is mostly blue on airplane and red on dog, while the regions along the dog's outstretched feet are lit red on airplane and blue on dog.

What this implies is that while the model recognizes that the head and neck region of the image is most likely that of a dog, the fact that the dog is in a stretched-out position suggests an aerodynamic shape most common in airplanes. Most likely, there are not enough images of dogs in that position in the training set for the model to properly learn the distinction.

Using Imported Images

By extending the function written in the previous section, we can make it receive an uploaded image, make a prediction, and then provide model explainability via a SHAP plot. This is done below.

def plot_shap_util(filepath, mask, model):
  """
  This function performs model explainability
  by producing shap plots for a data instance
  """
  #  reading image and converting to tensor
  image = cv2.imread(filepath)
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  image = cv2.resize(image, (32, 32))
  image = transforms.ToTensor()(image)
  image = image.to(device)

  #-----------------
  #  CLASSIFICATION
  #-----------------
  #  creating a mapping of classes to labels  
  label_dict = {0:'airplane', 1:'automobile', 2:'bird', 3:'cat', 4:'deer',
                5:'dog', 6:'frog', 7:'horse', 8:'ship', 9:'truck'}

  #  utilizing the model for classification
  with torch.no_grad():
    prediction = torch.argmax(model(image), dim=1).item()

  #  displaying model classification
  print(f'prediction: {label_dict[prediction]}')

  #----------------
  #  EXPLAINABILITY
  #----------------
  #  creating dataloader for mask
  mask_loader = DataLoader(mask, batch_size=200)

  #  creating explainer for model behaviour
  for images in mask_loader:
    images = images.to(device)
    explainer = shap.DeepExplainer(model, images)
    break

  #  deriving shap values for image of interest based on model behaviour
  shap_values = explainer.shap_values(image.view(-1, 3, 32, 32))

  #  preparing for visualization by changing channel arrangement
  shap_numpy = [np.swapaxes(np.swapaxes(x, 1, -1), 1, 2) for x in shap_values]
  test_numpy = np.swapaxes(np.swapaxes(image.view(-1, 3, 32, 32).cpu().numpy(), 1, -1), 1, 2)

  #  producing shap plots
  shap.image_plot(shap_numpy, test_numpy, show=False,
                  labels=['airplane', 'automobile', 'bird', 'cat', 'deer',
                          'dog', 'frog', 'horse', 'ship', 'truck'])

Using the extended function, we can supply an image's file path as a parameter; a classification will be provided, followed by a SHAP plot which can then be interpreted for explainability.

#  using the extended explainability function
plot_shap_util('image.jpg', mask, model.network)

In this case, the model has correctly classified the uploaded image as that of a Horse, since it has fewer blue pixels and more red pixels compared to the other classes. Though in this case, a localized region along the base of the image seems to play a huge role in the classification, which is difficult to decipher.

Final Remarks

Model explainability helps to provide some useful insight into why a model behaves the way it does, even though not all explanations may make sense or be easy to interpret. SHAP is just one way to explain the outputs of deep learning models; there exist numerous other libraries that can be used to the same effect.

Note: For this article, better explanations could be obtained with a better model, in the sense of both a better architecture and better model performance. Feel free to change the model architecture or train the model for more epochs if deemed necessary.
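Since the train() method re-initializes the network's weights each time it is called (via init_weights), training for more epochs means retraining from scratch with a larger epoch count, for example:

#  retraining from scratch for more epochs
model = ConvolutionalNeuralNet(ConvNet())

log_dict = model.train(nn.CrossEntropyLoss(), epochs=30, batch_size=64,
                       training_set=training_set, validation_set=validation_set)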
