Object Detection Using Mask R-CNN with TensorFlow 2.0 and Keras

3 years ago   •   9 min read

By Ahmed Fawzy Gad

In a previous tutorial, we saw how to use the open-source GitHub project Mask_RCNN with Keras and TensorFlow 1.14. In this tutorial, we inspect the project and replace its TensorFlow 1.14 features with equivalents that are compatible with TensorFlow 2.0.

Specifically, we'll cover:

  • Four edits to make predictions with Mask R-CNN using TensorFlow 2.0
  • Five edits to train Mask R-CNN with TensorFlow 2.0
  • A summary of all the changes to be made
  • Conclusion

Before starting, check out the previous tutorial to download and run the Mask_RCNN project.


Edits to Make Predictions with Mask R-CNN Using TensorFlow 2.0

The Mask_RCNN project works only with TensorFlow 1.x (for example, 1.13 or 1.14). Because TensorFlow 2.0 offers more features and enhancements, many developers want to migrate to it.

Some tools can help automatically convert TensorFlow 1.x code to TensorFlow 2.0, but they are not guaranteed to produce fully functional code. Check the tf_upgrade_v2 upgrade script that Google ships with TensorFlow 2.0.

In this section, the required changes to the Mask R-CNN project are discussed so that it fully supports TensorFlow 2.0 for making predictions (i.e. when the mode parameter in the mrcnn.model.MaskRCNN class constructor is set to inference). In a later section, more edits are applied to train the Mask R-CNN model in TensorFlow 2.0 (i.e. when the mode parameter in the mrcnn.model.MaskRCNN class constructor is set to training).

If you have TensorFlow 2.0 installed, running the following code block to perform inference will raise exceptions. Please consider downloading the trained weights mask_rcnn_coco.h5 from this link.

import mrcnn
import mrcnn.config
import mrcnn.model
import mrcnn.visualize
import cv2
import os

CLASS_NAMES = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

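class SimpleConfig(mrcnn.config.Config):
    # Give the configuration a recognizable name.
    NAME = "coco_inference"

    # Batch size is GPU_COUNT * IMAGES_PER_GPU = 1, so detect() receives one image.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Number of classes, including the background class 'BG'.
    NUM_CLASSES = len(CLASS_NAMES)

# Build the Mask R-CNN model in inference mode.
model = mrcnn.model.MaskRCNN(mode="inference",
                             config=SimpleConfig(),
                             model_dir=os.getcwd())

# Load the COCO-trained weights into the model.
model.load_weights(filepath="mask_rcnn_coco.h5",
                   by_name=True)

# OpenCV loads images in BGR order; convert to RGB before detection.
image = cv2.imread("sample2.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Run detection on a single image.
r = model.detect([image], verbose=0)

# detect() returns one result dictionary per input image.
r = r[0]

# Draw the predicted boxes, masks, class names, and scores.
mrcnn.visualize.display_instances(image=image,
                                  boxes=r['rois'],
                                  masks=r['masks'],
                                  class_ids=r['class_ids'],
                                  class_names=CLASS_NAMES,
                                  scores=r['scores'])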

The next subsections discuss the required changes to support TensorFlow 2.0 and to solve all the exceptions.

1. tf.log()

When the previous code runs, an exception is raised from this line in the mrcnn.model.log2_graph() function:

return tf.log(x) / tf.log(2.0)

The exception text is given below. It indicates that TensorFlow has no attribute called log().

...
File "D:\Object Detection\Tutorial\code\mrcnn\model.py", in log2_graph
  return tf.log(x) / tf.log(2.0)

AttributeError: module 'tensorflow' has no attribute 'log'

In TensorFlow 1.x, the log() function was available at the root of the library. Because TensorFlow 2.0 reorganized many functions into submodules, log() is now available only in the tf.math module. So, rather than tf.log(), use tf.math.log().

To fix the issue, just locate the mrcnn.model.log2_graph() function. Here is its code:

def log2_graph(x):
    """Implementation of Log2. TF doesn't have a native implementation."""
    return tf.log(x) / tf.log(2.0)

Replace each tf.log with tf.math.log. The new function should be:

def log2_graph(x):
    """Implementation of Log2. TF doesn't have a native implementation."""
    return tf.math.log(x) / tf.math.log(2.0)
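log2_graph() relies on the change-of-base identity log2(x) = ln(x) / ln(2), which is why it divides two natural logarithms. As a quick check of that identity, the sketch below repeats the same computation with Python's math module instead of TensorFlow; it is an illustration only and is not part of the Mask_RCNN project.

```python
import math


def log2_change_of_base(x):
    """Base-2 logarithm computed the same way as log2_graph():
    the natural log of x divided by the natural log of 2."""
    return math.log(x) / math.log(2.0)


for value in (1.0, 2.0, 8.0, 0.5):
    # Compare against the standard library's direct base-2 logarithm.
    print(value, log2_change_of_base(value), math.log2(value))
```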

2. tf.sets.set_intersection()

Running the code again raises another exception, this time from this line inside the mrcnn.model.refine_detections_graph() function:

keep = tf.sets.set_intersection(tf.expand_dims(keep, 0), tf.expand_dims(conf_keep, 0))

The exception is given below.

File "D:\Object Detection\Tutorial\code\mrcnn\model.py", line 720, in refine_detections_graph
  keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),

AttributeError: module 'tensorflow_core._api.v2.sets' has no attribute 'set_intersection'

The issue occurs because the set_intersection() function from TensorFlow 1.x was renamed to intersection() in TensorFlow 2.0.

To fix the issue, just use tf.sets.intersection() rather than tf.sets.set_intersection(). The new line is:

keep = tf.sets.intersection(tf.expand_dims(keep, 0), tf.expand_dims(conf_keep, 0))

Note that tf.sets.set_intersection() is also used elsewhere in the file, so search for all of its occurrences and replace each one with tf.sets.intersection().
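Searching by hand is error-prone, so a small script can list every line that still calls a removed name. The sketch below is a hypothetical helper (find_occurrences() is not part of Mask_RCNN); it runs a regular expression over source text and is demonstrated on an inline snippet rather than on the real file.

```python
import re


def find_occurrences(source, old_call):
    """Return the 1-based line numbers in `source` that call `old_call`.

    `old_call` is a dotted name such as "tf.sets.set_intersection". The
    pattern requires a word boundary before the name and "(" after it, so
    "tf.log" does not match "tf.math.log(" or "tf.logical_and(".
    """
    pattern = re.compile(r"\b" + re.escape(old_call) + r"\s*\(")
    return [
        number
        for number, line in enumerate(source.splitlines(), start=1)
        if pattern.search(line)
    ]


sample = """keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
                                tf.expand_dims(conf_keep, 0))
keep = tf.sparse_tensor_to_dense(keep)[0]
y = tf.math.log(x)
"""
print(find_occurrences(sample, "tf.sets.set_intersection"))  # [1]
```

To scan the project itself, read mrcnn/model.py (and mrcnn/utils.py) into a string and pass it as source.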

3. tf.sparse_tensor_to_dense()

With these two issues fixed, running the code again raises an exception from this line in the mrcnn.model.refine_detections_graph() function:

keep = tf.sparse_tensor_to_dense(keep)[0]

Here is the exception. The TensorFlow 1.x function sparse_tensor_to_dense() is available in TensorFlow 2.0 only through the tf.sparse module, as tf.sparse.to_dense().

File "D:\Object Detection\Tutorial\code\mrcnn\model.py", in refine_detections_graph
  keep = tf.sparse_tensor_to_dense(keep)[0]

AttributeError: module 'tensorflow' has no attribute 'sparse_tensor_to_dense'

To fix it, call tf.sparse.to_dense instead of tf.sparse_tensor_to_dense. The new line should be:

keep = tf.sparse.to_dense(keep)[0]

Do that for all occurrences of tf.sparse_tensor_to_dense.
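tf.sets.intersection() returns a SparseTensor, which is why the code converts its result with tf.sparse.to_dense() and then takes row 0. The plain-Python sketch below imitates that pipeline on lists, using (indices, values, dense_shape) tuples as a stand-in for a SparseTensor. The helper names are hypothetical, and sorting the shared elements is an assumption of the sketch.

```python
def set_intersection_rows(a_rows, b_rows):
    """Row-wise intersection returned in COO form, like a SparseTensor:
    (indices, values, dense_shape). Shared elements are sorted here,
    which is an assumption of this sketch."""
    indices, values = [], []
    width = 0
    for row, (a, b) in enumerate(zip(a_rows, b_rows)):
        common = sorted(set(a) & set(b))
        width = max(width, len(common))
        for col, value in enumerate(common):
            indices.append((row, col))
            values.append(value)
    return indices, values, (len(a_rows), width)


def sparse_to_dense(indices, values, dense_shape, default_value=0):
    """Scatter COO values into a nested list, filling gaps with default_value."""
    rows, cols = dense_shape
    dense = [[default_value] * cols for _ in range(rows)]
    for (row, col), value in zip(indices, values):
        dense[row][col] = value
    return dense


# Plain-list stand-ins for the keep and conf_keep index tensors.
keep = [0, 2, 3, 5, 7]
conf_keep = [2, 4, 5, 7]

# Wrapping each list in another list plays the role of tf.expand_dims(x, 0).
sparse = set_intersection_rows([keep], [conf_keep])
keep = sparse_to_dense(*sparse)[0]
print(keep)  # [2, 5, 7]
```

With a single row every dense position is filled; with several rows of different sizes, the shorter rows are padded with default_value.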

4. tf.to_float()

Another exception is raised by this line inside the mrcnn.model.refine_detections_graph() function:

tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis]

The exception below occurs because the to_float() function from TensorFlow 1.x no longer exists in TensorFlow 2.0.

File "D:\Object Detection\Tutorial\code\mrcnn\model.py", in refine_detections_graph
  tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],

AttributeError: module 'tensorflow' has no attribute 'to_float'

As a replacement for the to_float() function in TensorFlow 2.0, use the tf.cast() function as follows:

tf.cast(value, tf.float32)

To fix the exception, replace the previous line with the following one:

tf.cast(tf.gather(class_ids, keep), tf.float32)[..., tf.newaxis]
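Because tf.to_float() takes an arbitrary expression, renaming the function is not enough: the closing parenthesis has to move and gain the dtype argument. The sketch below is a hypothetical helper (not part of Mask_RCNN or TensorFlow) that performs this rewrite on source text by counting parenthesis depth. It assumes the calls are written exactly as tf.to_float( and does not handle parentheses inside string literals.

```python
def rewrite_to_float(source, old="tf.to_float(", dtype="tf.float32"):
    """Rewrite each `tf.to_float(expr)` in `source` as `tf.cast(expr, dtype)`.

    The end of `expr` is found by counting parenthesis depth, because expr
    may itself contain calls. Parentheses inside string literals are not
    handled by this sketch.
    """
    out = []
    pos = 0
    while True:
        start = source.find(old, pos)
        if start == -1:
            out.append(source[pos:])  # copy the rest unchanged
            return "".join(out)
        depth = 1
        i = start + len(old)
        while depth:
            if i >= len(source):
                raise ValueError("unbalanced parentheses after " + old)
            if source[i] == "(":
                depth += 1
            elif source[i] == ")":
                depth -= 1
            i += 1
        # source[i - 1] is the matching ")"; nested calls are rewritten too.
        inner = rewrite_to_float(source[start + len(old):i - 1], old, dtype)
        out.append(source[pos:start])
        out.append("tf.cast(" + inner + ", " + dtype + ")")
        pos = i


line = "tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis]"
print(rewrite_to_float(line))
# tf.cast(tf.gather(class_ids, keep), tf.float32)[..., tf.newaxis]
```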

Summary of changes for inference

To make predictions using Mask R-CNN in TensorFlow 2.0, there are 4 changes to be made in the mrcnn.model script:

  1. Replace tf.log() with tf.math.log()
  2. Replace tf.sets.set_intersection() with tf.sets.intersection()
  3. Replace tf.sparse_tensor_to_dense() with tf.sparse.to_dense()
  4. Replace tf.to_float(value) with tf.cast(value, tf.float32)
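The first three changes in this list are pure renames, so they can be applied mechanically. The following sketch shows one hypothetical way to do it on source text; the INFERENCE_RENAMES dictionary and the apply_renames() helper are illustrative names, and the result should still be reviewed by hand, in line with the earlier caveat about automatic conversion.

```python
import re

# Renames for changes 1 to 3. tf.to_float() also needs its argument
# rewritten, so it is not a simple rename and is left out here.
INFERENCE_RENAMES = {
    "tf.log": "tf.math.log",
    "tf.sets.set_intersection": "tf.sets.intersection",
    "tf.sparse_tensor_to_dense": "tf.sparse.to_dense",
}


def apply_renames(source, renames):
    """Replace each old dotted call name with its new name.

    Only names followed by "(" are replaced, and a word boundary is
    required before the name, so "tf.math.log(" is left untouched and
    running the function twice changes nothing further.
    """
    for old, new in renames.items():
        pattern = re.compile(r"\b" + re.escape(old) + r"(?=\s*\()")
        source = pattern.sub(new, source)
    return source


snippet = "return tf.log(x) / tf.log(2.0)"
print(apply_renames(snippet, INFERENCE_RENAMES))
# return tf.math.log(x) / tf.math.log(2.0)
```

Reading mrcnn/model.py into a string, passing it through apply_renames(), and writing the result back covers changes 1 to 3; change 4 still needs an argument-aware rewrite.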

After making all of these changes, the code we saw at the beginning of this article can successfully run in TensorFlow 2.0.

Edits to Train Mask R-CNN Using TensorFlow 2.0

Assuming that you have TensorFlow 2.0 installed, running the code block below to train Mask R-CNN on the Kangaroo Dataset will raise a number of exceptions. This section inspects the changes to be made to train Mask R-CNN in TensorFlow 2.0.

Please consider downloading the Kangaroo dataset, in addition to the weights mask_rcnn_coco.h5 from this link.

import os
import xml.etree.ElementTree  # import the submodule so xml.etree.ElementTree.parse() is available
from numpy import zeros, asarray

import mrcnn.utils
import mrcnn.config
import mrcnn.model

class KangarooDataset(mrcnn.utils.Dataset):

	def load_dataset(self, dataset_dir, is_train=True):
		self.add_class("dataset", 1, "kangaroo")

		images_dir = dataset_dir + '/images/'
		annotations_dir = dataset_dir + '/annots/'

		for filename in os.listdir(images_dir):
			image_id = filename[:-4]

			if image_id in ['00090']:
				continue

			if is_train and int(image_id) >= 150:
				continue

			if not is_train and int(image_id) < 150:
				continue

			img_path = images_dir + filename
			ann_path = annotations_dir + image_id + '.xml'

			self.add_image('dataset', image_id=image_id, path=img_path, annotation=ann_path)

	def extract_boxes(self, filename):
		tree = xml.etree.ElementTree.parse(filename)

		root = tree.getroot()

		boxes = list()
		for box in root.findall('.//bndbox'):
			xmin = int(box.find('xmin').text)
			ymin = int(box.find('ymin').text)
			xmax = int(box.find('xmax').text)
			ymax = int(box.find('ymax').text)
			coors = [xmin, ymin, xmax, ymax]
			boxes.append(coors)

		width = int(root.find('.//size/width').text)
		height = int(root.find('.//size/height').text)
		return boxes, width, height

	def load_mask(self, image_id):
		info = self.image_info[image_id]
		path = info['annotation']
		boxes, w, h = self.extract_boxes(path)
		masks = zeros([h, w, len(boxes)], dtype='uint8')

		class_ids = list()
		for i in range(len(boxes)):
			box = boxes[i]
			row_s, row_e = box[1], box[3]
			col_s, col_e = box[0], box[2]
			masks[row_s:row_e, col_s:col_e, i] = 1
			class_ids.append(self.class_names.index('kangaroo'))
		return masks, asarray(class_ids, dtype='int32')

class KangarooConfig(mrcnn.config.Config):
    NAME = "kangaroo_cfg"

    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    
    NUM_CLASSES = 2

    STEPS_PER_EPOCH = 131

train_set = KangarooDataset()
train_set.load_dataset(dataset_dir='kangaroo', is_train=True)
train_set.prepare()

valid_dataset = KangarooDataset()
valid_dataset.load_dataset(dataset_dir='kangaroo', is_train=False)
valid_dataset.prepare()

kangaroo_config = KangarooConfig()

model = mrcnn.model.MaskRCNN(mode='training', 
                             model_dir='./', 
                             config=kangaroo_config)

model.load_weights(filepath='mask_rcnn_coco.h5', 
                   by_name=True, 
                   exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",  "mrcnn_bbox", "mrcnn_mask"])

model.train(train_dataset=train_set, 
            val_dataset=valid_dataset, 
            learning_rate=kangaroo_config.LEARNING_RATE, 
            epochs=1, 
            layers='heads')

model_path = 'Kangaro_mask_rcnn.h5'
model.keras_model.save_weights(model_path)

1. tf.random_shuffle()

Running the previous code raises an exception when the following line inside the mrcnn.model.detection_targets_graph() function is executed.

positive_indices = tf.random_shuffle(positive_indices)[:positive_count]

The exception below indicates that TensorFlow has no function named random_shuffle().

File "D:\mrcnn\model.py", in detection_targets_graph
  positive_indices = tf.random_shuffle(positive_indices)[:positive_count]

AttributeError: module 'tensorflow' has no attribute 'random_shuffle'

Due to the new organization of TensorFlow 2.0 functions, the TensorFlow 1.x function tf.random_shuffle() is replaced by the shuffle() function in the tf.random module. Thus, tf.random_shuffle() should be replaced by tf.random.shuffle().

The previous line should be:

positive_indices = tf.random.shuffle(positive_indices)[:positive_count]

Please check all occurrences of tf.random_shuffle() and make the necessary change.
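The line shuffles the candidate indices and keeps the first positive_count of them, that is, it samples positive ROIs without replacement. The sketch below reproduces that logic with Python's random module so the behavior is easy to inspect. It is an analogue for illustration only, not a replacement for the TensorFlow op, which shuffles a tensor along its first dimension inside the graph; shuffle_then_slice() is a hypothetical name.

```python
import random


def shuffle_then_slice(indices, count, rng=None):
    """Plain-Python analogue of tf.random.shuffle(indices)[:count].

    Shuffling a copy and keeping the first `count` items samples the
    indices without replacement.
    """
    if rng is None:
        rng = random.Random()
    shuffled = list(indices)  # shuffle a copy; the input is left unchanged
    rng.shuffle(shuffled)
    return shuffled[:count]


positive_indices = [3, 8, 11, 15, 21, 42]
# A seeded generator makes the example repeatable.
print(shuffle_then_slice(positive_indices, 4, random.Random(0)))
```

For counts no larger than the list, random.sample(positive_indices, 4) gives the same kind of result in one call; the two-step form is kept to match the TensorFlow line.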

2. tf.log()

An exception is also raised from the following line inside the mrcnn.utils.box_refinement_graph() function. Note that this function lives in mrcnn/utils.py, not in mrcnn/model.py.

dh = tf.log(gt_height / height)

The exception, given below, indicates that the function tf.log() does not exist.

File "D:\mrcnn\utils.py", in box_refinement_graph
  dh = tf.log(gt_height / height)

AttributeError: module 'tensorflow' has no attribute 'log'

In TensorFlow 2.0, the log() function is moved into the math module. Thus, tf.log() should be replaced by tf.math.log(). The previous line should be:

dh = tf.math.log(gt_height / height)

Make this change for all occurrences of the tf.log() function.
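To make the role of the logarithm concrete, here is a plain-Python sketch of the per-box arithmetic for a single pair of boxes, assuming the (y1, x1, y2, x2) box order that Mask_RCNN uses. It is written from the standard box-delta formulas rather than copied from mrcnn.utils, so treat it as an illustration; box_refinement() here is a hypothetical name.

```python
import math


def box_refinement(box, gt_box):
    """Deltas (dy, dx, dh, dw) that move `box` onto `gt_box`.

    Both boxes are (y1, x1, y2, x2). The height and width terms are
    log-ratios, which is where tf.math.log() appears in the graph version.
    """
    height = box[2] - box[0]
    width = box[3] - box[1]
    center_y = box[0] + 0.5 * height
    center_x = box[1] + 0.5 * width

    gt_height = gt_box[2] - gt_box[0]
    gt_width = gt_box[3] - gt_box[1]
    gt_center_y = gt_box[0] + 0.5 * gt_height
    gt_center_x = gt_box[1] + 0.5 * gt_width

    dy = (gt_center_y - center_y) / height
    dx = (gt_center_x - center_x) / width
    dh = math.log(gt_height / height)
    dw = math.log(gt_width / width)
    return dy, dx, dh, dw


# A ground-truth box twice as tall and wide: dy = dx = 0.5, dh = dw = log(2).
print(box_refinement((0, 0, 10, 10), (0, 0, 20, 20)))
```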

3. Tensor Membership

An exception is raised when executing the following if statement inside the compile() method of the mrcnn.model.MaskRCNN class.

if layer.output in self.keras_model.losses:

Before looking at the exception below, let's explain what causes it.

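Both layer.output and the entries of self.keras_model.losses are tensors; losses itself is a Python list of loss tensors. To evaluate the membership test, Python compares layer.output with each entry using == and then treats each result as a bool. In TensorFlow 1.x, == on graph tensors compared object identity and returned a Python bool. In TensorFlow 2.0, it can instead return another tensor holding the elementwise comparison, and Python cannot use that symbolic tensor as a bool in graph execution.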

File "D:\mrcnn\model.py", in compile
  if layer.output in self.keras_model.losses:
	...
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

As the following code shows, the purpose of the if statement is to skip a layer whose loss already exists in the self.keras_model.losses list. Otherwise, the code computes the weighted loss and adds it to the model with add_loss().

...

loss_names = ["rpn_class_loss",  "rpn_bbox_loss",
              "mrcnn_class_loss", "mrcnn_bbox_loss", 
              "mrcnn_mask_loss"]

for name in loss_names:
    layer = self.keras_model.get_layer(name)

    if layer.output in self.keras_model.losses:
        continue

    loss = (tf.reduce_mean(layer.output, keepdims=True) * self.config.LOSS_WEIGHTS.get(name, 1.))
    self.keras_model.add_loss(loss)

...

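Before the for loop runs, the self.keras_model.losses list is empty, so it contains no losses at all. Moreover, the tensors the loop adds are new loss tensors computed from layer.output, not layer.output itself, so the check can never find a match. As a result, the if statement can safely be dropped. The solution is to comment it out, as in the next code.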

...

loss_names = ["rpn_class_loss",  "rpn_bbox_loss",
              "mrcnn_class_loss", "mrcnn_bbox_loss", 
              "mrcnn_mask_loss"]

for name in loss_names:
    layer = self.keras_model.get_layer(name)

    # if layer.output in self.keras_model.losses:
    #     continue

    loss = (tf.reduce_mean(layer.output, keepdims=True) * self.config.LOSS_WEIGHTS.get(name, 1.))
    self.keras_model.add_loss(loss)

...
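The failure can be reproduced without TensorFlow. In the toy classes below (hypothetical stand-ins, not TensorFlow types), == returns an object that refuses to become a bool, which is how a symbolic tensor behaves in graph execution. The same example shows an identity check with is, which would be one alternative to commenting out the if statement, since it matches the TensorFlow 1.x behavior of comparing tensors by object identity.

```python
class SymbolicBool:
    """Stands in for the tensor that == returns in graph execution."""

    def __bool__(self):
        raise TypeError("using a symbolic tensor as a Python bool is not allowed")


class FakeTensor:
    """Toy stand-in for a graph tensor whose == is elementwise."""

    def __eq__(self, other):
        return SymbolicBool()


output = FakeTensor()
losses = [FakeTensor(), FakeTensor()]

# `output in losses` compares output with each element using ==,
# then asks Python for the truth value of the result, which fails.
try:
    found = output in losses
except TypeError as error:
    print("membership failed:", error)

# Comparing identities with `is` never calls __eq__, so it is safe.
print(any(loss is output for loss in losses))  # False
```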

Let's discuss another exception.

4. metrics_tensors

An exception is raised when executing the following line inside the compile() method of the mrcnn.model.MaskRCNN class.

self.keras_model.metrics_tensors.append(loss)

According to the following error, the Keras Model object stored in keras_model has no attribute named metrics_tensors.

File "D:\mrcnn\model.py", in compile
  self.keras_model.metrics_tensors.append(loss)

AttributeError: 'Model' object has no attribute 'metrics_tensors'

The solution is to initialize metrics_tensors as an empty list at the beginning of the compile() method:

class MaskRCNN():
    ...
    def compile(self, learning_rate, momentum):
        self.keras_model.metrics_tensors = []
        ...

The next section discusses the last change to be made.

5. Save Training Logs

The last exception is related to the log directory that the following line assigns inside the set_log_dir() method of the mrcnn.model.MaskRCNN class.

self.log_dir = os.path.join(self.model_dir, "{}{:%Y%m%dT%H%M}".format(self.config.NAME.lower(), now))

The exception, given below, indicates a problem creating a directory.

NotFoundError: Failed to create a directory: ./kangaroo_cfg20200918T0338\train\plugins\profile\2020-09-18_03-39-26; No such file or directory

The exception can be avoided by assigning a valid directory manually. For example, the following directory is valid on my PC; replace it with a valid directory on yours.

self.log_dir = "D:\\Object Detection\\Tutorial\\logs"
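Rather than hard-coding a path, another option, offered here as an assumption rather than a verified fix, is to resolve the log directory to an absolute, normalized path and create it up front, since the failing path in the error starts with a relative ./ and mixes / and \ separators. The make_log_dir() helper below is hypothetical; inside set_log_dir(), its result would be assigned to self.log_dir.

```python
import datetime
import os
import tempfile


def make_log_dir(model_dir, config_name, now=None):
    """Create a timestamped log directory and return its absolute path.

    Uses the same "{name}{timestamp}" naming as the set_log_dir() line,
    but passes the result through os.path.abspath(), which also
    normalizes the path separators for the current platform.
    """
    if now is None:
        now = datetime.datetime.now()
    name = "{}{:%Y%m%dT%H%M}".format(config_name.lower(), now)
    log_dir = os.path.abspath(os.path.join(model_dir, name))
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


# Demonstrate with a throwaway base directory instead of the project folder.
base_dir = tempfile.mkdtemp()
print(make_log_dir(base_dir, "kangaroo_cfg"))
```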

This is the last change to be made so that the Mask_RCNN project can train the Mask R-CNN model in TensorFlow 2.0. The training code prepared previously can now be executed in TensorFlow 2.0.

Summary of changes to train Mask R-CNN in TensorFlow 2.0

To train the Mask R-CNN model using the Mask_RCNN project in TensorFlow 2.0, there are 5 changes to be made in the mrcnn.model and mrcnn.utils scripts:

  1. Replace tf.random_shuffle() with tf.random.shuffle()
  2. Replace tf.log() with tf.math.log()
  3. Comment out an if statement inside the compile() method.
  4. Initialize the metrics_tensors attribute at the beginning of the compile() method.
  5. Assign a valid directory to the self.log_dir attribute.

Conclusion

This tutorial edited the open-source Mask_RCNN project so that the Mask R-CNN model can be trained and can perform inference using TensorFlow 2.0.

To support both inference and training in TensorFlow 2.0, a total of 9 changes were applied: 4 to support making predictions, and 5 to enable training.
