One of the great promises of Deep Learning is its applicability to a wide variety of complex tasks. Recent years have seen an explosion in the number of fields where Deep Learning has been applied successfully, with enormous breakthroughs in Biology, Chemistry, Healthcare, and Physics in particular.
At Paperspace, part of our mission is to empower anyone interested in ML research, whether they are a seasoned practitioner or a relative newcomer, with tools that greatly improve and expedite their productivity. Both Andrew Ng and Jeremy Howard have commented on how Deep Learning will empower domain experts to achieve incredible breakthroughs in their respective fields, and organizations like DeepMind have achieved amazing things by applying Deep Learning to very specific problems like protein folding.
In this post, we're going to demonstrate how to build a state-of-the-art bacterial classification model on Gradient using the Fast.ai deep learning library. We'll start by understanding the task and examining our dataset. After this, we'll make some decisions about our architecture and training process, and compare our results to the current state of the art!
Understanding Bacterial Classification
While it may seem obscure, classifying bacterial species is actually very useful: bacteria are prevalent in our environment, and information about them matters in many fields, including agriculture and medicine. Building a system that can automatically recognize and classify these microorganisms would be incredibly valuable, and it remains an open research question today. It is also a surprisingly complex task: the shape of individual bacterial cells varies tremendously, as does the number of cells in a scene, and when examining colonies of bacteria, factors like colony size, texture, and composition come into play.
The data we'll be using today comes from the Digital Image of Bacterial Species (DIBaS) dataset, compiled as part of the study Deep learning approach to bacterial colony classification (Zieliński et al., 2017). It contains 660 images spanning 33 different genera and species of bacteria. We'll examine their results more carefully and compare them with our own later in the post!
Preprocessing Our Data
The work here was done in a Paperspace Gradient notebook using the Fast.ai template. All the packages used come preinstalled in that container, which makes for a quick start. DIBaS is actually a little hard to access automatically, as it is siloed into separate links on its website. So, to automate the download and save ourselves some time, we'll use a scraping library to collect and parse our data. Let's import some useful packages.
import requests
import urllib.request
import time
from bs4 import BeautifulSoup
import os
The package to keep an eye on is BeautifulSoup, which allows us to parse the HTML page after we grab it and search it for useful URLs (like the ones that hold our download links).
Let's grab the web page from the DIBaS site and parse it!
url = 'http://misztal.edu.pl/software/databases/dibas/'
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
os.makedirs('./bacteria-dataset/zip_files', exist_ok=True)
os.makedirs('./bacteria-dataset/full_images_alt', exist_ok=True)
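Before downloading anything, it helps to sanity-check which links the page actually contains, since the loop below relies on hard-coded tag indices. A minimal sketch (the .zip suffix filter is our own assumption about how the archives are linked):

# List candidate archive links on the page (the '.zip' suffix check is an assumption)
zip_links = [a['href'] for a in soup.find_all('a', href=True)
             if a['href'].endswith('.zip')]
print(len(zip_links), zip_links[:3])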
Now that we have our parsed URLs for each subfolder in the bacterial species dataset, we can use the urllib library to download the zip files and unpack them!
for i in range(19, 52):  # these 'a' tags hold the 33 species archive links
    one_a_tag = soup.findAll('a')[i]
    link = one_a_tag['href']
    # Save each archive under its original file name
    urllib.request.urlretrieve(link, './bacteria-dataset/zip_files/' + link.split('/')[-1])
    time.sleep(1)  # be polite to the server between downloads
import zipfile

for i in range(19, 52):
    one_a_tag = soup.findAll('a')[i]
    link = one_a_tag['href']
    # Unpack each downloaded archive into the shared images directory
    zip_ref = zipfile.ZipFile('./bacteria-dataset/zip_files/' + link.split('/')[-1], 'r')
    zip_ref.extractall('./bacteria-dataset/full_images_alt/')
    zip_ref.close()
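With everything unpacked, a quick sanity check confirms the download worked. This sketch assumes the DIBaS images unpack as .tif files under full_images_alt:

from pathlib import Path

# DIBaS should yield 660 TIFF images across its 33 species
image_paths = list(Path('./bacteria-dataset/full_images_alt').rglob('*.tif'))
print(len(image_paths))  # expect 660 if every archive unpacked cleanly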
Training Our Model
Now that our data is ready, we can move on to training our model. We're going to make use of the Fast.ai library for its ease of use, high-level abstractions, and powerful API. If you've taken the first lesson of the Practical Deep Learning for Coders course from Fast.ai (also known as Fast.ai Part 1), you're ready to understand everything we're doing here!
First, let's import the right modules from the library.
from fastai.vision import *
from fastai.metrics import error_rate
We can set some configurations and grab our files using the get_image_files utility from Fast.ai.
bs = 64  # batch size
fnames = get_image_files('bacteria-dataset/full_images_alt')
fnames[:5]
# Outputs are filepaths!
# [PosixPath('bacteria-dataset/full_images_alt/Actinomyces.israeli_0001.tif'),
# PosixPath('bacteria-dataset/full_images_alt/Actinomyces.israeli_0002.tif'),
# PosixPath('bacteria-dataset/full_images_alt/Actinomyces.israeli_0003.tif'),
# PosixPath('bacteria-dataset/full_images_alt/Actinomyces.israeli_0004.tif'),
# PosixPath('bacteria-dataset/full_images_alt/Actinomyces.israeli_0005.tif')]
Now, we'll make use of the ImageDataBunch class from Fast.ai, which creates a data structure that holds our dataset and its labels. The from_name_re constructor extracts each label from the file name with a regular expression, and since DIBaS file names embed the species name, it works out of the box!
np.random.seed(42)  # fix the random validation split for reproducibility
pat = r'/([^/]+)_\d+\.tif$'  # captures the species, e.g. 'Actinomyces.israeli' from 'Actinomyces.israeli_0001.tif'
data = ImageDataBunch.from_name_re('bacteria-dataset/full_images_alt', fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs).normalize(imagenet_stats)
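It's worth poking at the resulting ImageDataBunch to confirm the regex did its job. The expected values here are assumptions based on the dataset description (33 classes, 660 images, and fastai v1's default 20% validation split):

print(data.classes[:5])  # first few species labels parsed from the file names
print(data.c)            # number of classes; should be 33 for DIBaS
print(len(data.train_ds), len(data.valid_ds))  # default 80/20 train/validation split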
Now, we can create a CNN to learn from our dataset. CNNs are a natural fit here because we want to learn visual features wherever they appear in an image. We'll use ResNet34, which should work well; ResNets don't appear to have been applied to this task before, so they're an interesting area to explore. You can find a useful overview of ResNets here, and I've also included the original paper in the references section of this post!
learn = create_cnn(data, models.resnet34, metrics=error_rate)
Now, to train our model, we'll make use of the fit_one_cycle method. This method implements the 1cycle policy from Leslie Smith's paper: the learning rate is ramped up and then annealed over the course of training, a schedule that significantly reduces training time and often improves performance. We can see the outputs from the training process below.
learn.fit_one_cycle(4)
# Outputs:
# epoch train_loss valid_loss error_rate
# 1 3.817713 2.944878 0.759124
# 2 2.632171 1.093049 0.248175
# 3 1.929509 0.544141 0.167883
# 4 1.509456 0.457186 0.145985
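As an aside, fastai's Recorder logs every batch, so we can visualize the schedule fit_one_cycle just ran (a quick sketch using fastai v1's Recorder API):

learn.recorder.plot_lr(show_moms=True)  # LR ramps up then anneals; momentum moves inversely
learn.recorder.plot_losses()            # training and validation loss over the run

The learning rate rises for roughly the first half of training and then decays, which is the 1cycle shape described above.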
Wow! Our model actually does quite well, reaching an error rate of ~14.5%! The real question is how this compares to the state of the art. The seminal work in the area is the paper that introduced the DIBaS dataset in the first place. The authors tested several different methods, ranging from CNNs to more traditional approaches like Support Vector Machines and Random Forests. Their best results were around 97% accuracy, a good deal better than ours.
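Before reaching for a bigger model, it can also help to see which species our ResNet34 actually confuses. A short sketch with fastai v1's standard interpretation tools (the min_val threshold is an arbitrary choice):

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12, 12), dpi=60)  # 33x33 species confusion matrix
interp.most_confused(min_val=2)  # pairs of species confused at least twice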
So, how can we improve our approach? ResNets scale well, so we can try a deeper architecture like ResNet50. We can also use the lr_find() method to find a good range of learning rates and use that to try and improve our model.
learn = create_cnn(data, models.resnet50, metrics=error_rate)
learn.lr_find()        # sweep the learning rate over a mock training run
learn.recorder.plot()  # plot loss against learning rate
This graph informs us as to where the learning rate drives the loss down fastest, which is pretty cool! For reference's sake, let's first train without using our knowledge of this range of learning rate values. We can use the same one-cycle training method for 8 epochs.
learn.fit_one_cycle(8)
# epoch train_loss valid_loss error_rate
# 1 2.853813 1.561166 0.306569
# 2 1.639013 0.248170 0.058394
# 3 1.101536 0.230741 0.080292
# 4 0.781610 0.159655 0.043796
# 5 0.587977 0.132877 0.036496
# 6 0.455316 0.115520 0.036496
# 7 0.356362 0.108675 0.029197
# 8 0.293171 0.109001 0.029197
Interesting! We can see that this model is a lot better than our previous one, and essentially matches the performance reported in the paper. 97.1% accuracy is nothing to laugh at!
But what if we use the knowledge of the learning rates we picked up earlier? Let's confine our one-cycle training to the range where the learning rate impacted the loss the most.
learn.save('stage-1-50')  # checkpoint before unfreezing
learn.unfreeze()          # make all layers trainable
learn.fit_one_cycle(3, max_lr=slice(1e-6, 1e-4))  # discriminative learning rates across layer groups
# Outputs
# epoch train_loss valid_loss error_rate
# 1 0.178638 0.100145 0.021898
# 2 0.176825 0.093956 0.014599
# 3 0.159130 0.092905 0.014599
Wow! Our new model reaches an accuracy of about 98.5%, which beats the results in the original paper. Of course, that paper is from 2017, and it makes sense that applying a very powerful model like a ResNet would yield great results.
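To actually use this model later, we can checkpoint and export it for inference. A minimal sketch using fastai v1's export/load_learner workflow (the checkpoint name and sample image path are illustrative, not from the original run):

learn.save('stage-2-50')  # checkpoint the fine-tuned weights
learn.export()            # write export.pkl next to the data for inference-only loading

# Later, or on another machine:
inference_learner = load_learner('bacteria-dataset/full_images_alt')
img = open_image('bacteria-dataset/full_images_alt/Actinomyces.israeli_0001.tif')
pred_class, pred_idx, probs = inference_learner.predict(img)
print(pred_class, probs[pred_idx])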
Conclusion
We've managed to get some pretty amazing results on a task from a domain many people may not be familiar with, and we've done it quickly thanks to Gradient and Fast.ai. Of course, this domain has likely seen progress since 2017, so there may be better, more nuanced approaches than throwing ResNets at the problem. Going forward, we may try different approaches to this bacterial classification task, or maybe even tackle some other new datasets! What other deep learning architectures do you think may be useful?
If you have an ML-related project or idea that you've been wanting to try, consider using Paperspace's Gradient platform! It's a fantastic tool that lets you do many useful things, including exploring and visualizing with notebooks, running serious training cycles and ML pipelines, and deploying your trained models as endpoints. Learn more about our free GPU-enabled Jupyter Notebooks here!
References
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
Smith, L. N. (2018). A disciplined approach to neural network hyper-parameters: Part 1--learning rate, batch size, momentum, and weight decay. arXiv preprint arXiv:1803.09820.
Zieliński, B., Plichta, A., Misztal, K., Spurek, P., Brzychczy-Włoch, M., & Ochońska, D. (2017). Deep learning approach to bacterial colony classification. PloS one, 12(9), e0184554.