Working with Custom Image Datasets in PyTorch

In this article, we take a look at working with custom datasets in PyTorch: curating a custom dataset via web scraping, loading and labeling it, and creating a PyTorch dataset from it.

By Oreolorun Olu-Ipinlaye

Many beginners may encounter some difficulty while attempting to use a custom, curated dataset with PyTorch. Having previously explored how to curate a custom image dataset (via web scraping), this article will serve as a guide on how to load and label a custom dataset to use with PyTorch.

Creating a Custom Dataset

This section borrows code from the article on curating datasets. The goal here is to curate a custom dataset for a model which is going to distinguish between men's athletic shoes/trainers and men's boots.

For brevity, I will not run through what the code does in detail; instead, I will provide a quick summary, as I assume you have read the previous article. If you haven't, not to worry: it's linked again here. You could also simply run the code blocks and you will be all set for the next section.

#  article dependencies
import cv2
import numpy as np
import os
import requests
from bs4 import BeautifulSoup
from urllib.request import urlopen
from urllib.request import Request
import time
from torch.utils.data import Dataset
import torch
from torchvision import transforms
from tqdm import tqdm

WebScraper Class

The class below contains methods which will help us curate a custom dataset: parsing HTML using the BeautifulSoup library, extracting image src links using tags and attributes of interest, and finally downloading/scraping images of interest from a web page. The methods are named accordingly.

class WebScraper():
    def __init__(self, headers, tag: str, attribute: dict,
                src_attribute: str, filepath: str, count=0):
      self.headers = headers
      self.tag = tag
      self.attribute = attribute
      self.src_attribute = src_attribute
      self.filepath = filepath
      self.count = count
      self.bs = []
      self.interest = []

    def __str__(self):
      display = f"""      CLASS ATTRIBUTES
      headers: headers used so as to mimic requests coming from web browsers.
      tag: html tags intended for scraping.
      attribute: attributes of the html tags of interest.
      src_attribute: the tag attribute which holds the image src links.
      filepath: path ending with filenames to use when scraping images.
      count: numerical suffix to differentiate files in the same folder.
      bs: a list of each page's beautifulsoup elements.
      interest: a list of each page's image links."""
      return display

    def __repr__(self):
      display = f"""      CLASS ATTRIBUTES
      headers: {self.headers}
      tag: {self.tag}
      attribute: {self.attribute}
      src_attribute: {self.src_attribute}
      filepath: {self.filepath}
      count: {self.count}
      bs: {self.bs}
      interest: {self.interest}"""
      return display

    def parse_html(self, url):
      """
      This method requests the webpage from the server and
      returns a beautifulsoup element
      """
      try:
        request = Request(url, headers=self.headers)
        html = urlopen(request)
        bs = BeautifulSoup(html.read(), 'html.parser')
        self.bs.append(bs)
      except Exception as e:
        print(f'problem with webpage\n{e}')

    def extract_src(self):
      """
      This method extracts tags of interest from the webpage's
      html
      """
      #  extracting tag of interest
      interest = self.bs[-1].find_all(self.tag, attrs=self.attribute)
      interest = [listing[self.src_attribute] for listing in interest]
      self.interest.append(interest)
    
    def scrape_images(self):
      """
      This method grabs images located in the src links and
      saves them as required
      """
      for link in tqdm(self.interest[-1]):
        try:
          with open(f'{self.filepath}_{self.count}.jpg', 'wb') as f:
            response = requests.get(link)
            image = response.content
            f.write(image)
            self.count+=1
            time.sleep(0.4)
        except Exception as e:
          print(f'problem with image\n{e}')
          time.sleep(0.4)

Scraping Function

In order to iterate through multiple pages with our web scraper, we need to wrap it in a function. The function below is written to that effect: it contains the URL of interest formatted as an f-string, so that the page reference contained within the URL can be iterated over.

def my_scraper(scraper, page_range: list):
    """
    This function wraps around the web scraper class allowing it to scrape
    multiple pages. The argument page_range takes both a list of two elements
    to define a range of pages or a list of one element to define a single page.
    """
    if len(page_range) > 1:
      for i in range(page_range[0], page_range[1] + 1):
        scraper.parse_html(url=f'https://www.jumia.com.ng/mlp-fashion-deals/mens-athletic-shoes/?page={i}#catalog-listing')
        scraper.extract_src()
        scraper.scrape_images()
        print(f'\npage {i} done.')
      print('All Done!')
    else:
      scraper.parse_html(url=f'https://www.jumia.com.ng/mlp-fashion-deals/mens-athletic-shoes/?page={page_range[0]}#catalog-listing')
      scraper.extract_src()
      scraper.scrape_images()
      print('\nAll Done!')

Creating Directories

Since the goal is to curate a dataset of men's shoes, we need to create directories to that effect. For neatness, we create a parent directory in the root named shoes; this directory then contains two subdirectories, athletic and boots, which will hold the corresponding images.

#  create directories to hold images (makedirs also creates the parent
#  'shoes' directory; exist_ok prevents errors on re-runs)
os.makedirs('shoes/athletic', exist_ok=True)
os.makedirs('shoes/boots', exist_ok=True)

Scraping Images

Firstly, we need to define an adequate header for our web scraper. The header helps to mask the scraper, as it mimics a request from an actual web browser. Afterwards, we can instantiate a scraper for athletic shoe images using the header we defined, the tag we want to extract images from (img), the attribute of the tags of interest (class: img), the attribute that holds the image links (data-src), the filepath of interest terminating in a filename, and the starting point of the count suffix to be included in the filename. We can then pass the athletic scraper to the my_scraper function, since it already contains URLs pertaining to athletic shoes.

headers = {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11',
          'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
          'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
          'Accept-Encoding': 'none',
          'Accept-Language': 'en-US,en;q=0.8',
          'Connection': 'keep-alive'}
#  scrape athletic shoe images
athletic_scraper = WebScraper(headers=headers, tag='img', attribute = {'class':'img'},
                              src_attribute='data-src', filepath='shoes/athletic/atl', count=0)
                        
my_scraper(scraper=athletic_scraper, page_range=[1, 3])

In order to scrape images of boots, copy the two URLs in the comments below and replace the current URLs in the my_scraper function. From there, the boot scraper is instantiated in the same way as the athletic scraper and passed to the my_scraper function in order to scrape boot images.

#  replace the urls in the my scraper function with the urls below
#  first url:
#  f'https://www.jumia.com.ng/mlp-fashion-deals/mens-boots/?page={i}#catalog-listing'
#  second url:
#  f'https://www.jumia.com.ng/mlp-fashion-deals/mens-boots/?page={page_range[0]}#catalog-listing'
#  rerun my_scraper function code cell

#  scrape boot images
boot_scraper = WebScraper(headers=headers, tag='img', attribute = {'class':'img'},
                          src_attribute='data-src', filepath='shoes/boots/boot', count=0)
                        
my_scraper(scraper=boot_scraper, page_range=[1, 3])
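
If you would rather not edit the URLs inside my_scraper each time the target changes, you could parameterize the URL instead. Below is a minimal sketch of that idea; the my_scraper_v2 name and the url_template parameter are additions of mine, not part of the original article's code.

def my_scraper_v2(scraper, page_range: list, url_template: str):
    """
    A parameterized variant of my_scraper (an addition of this guide,
    not part of the original code). url_template should contain a '{}'
    placeholder where the page number belongs.
    """
    if len(page_range) > 1:
      pages = range(page_range[0], page_range[1] + 1)
    else:
      pages = page_range
    for i in pages:
      scraper.parse_html(url=url_template.format(i))
      scraper.extract_src()
      scraper.scrape_images()
      print(f'\npage {i} done.')
    print('All Done!')

#  usage example:
#  my_scraper_v2(boot_scraper, [1, 3],
#                'https://www.jumia.com.ng/mlp-fashion-deals/mens-boots/?page={}#catalog-listing')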

When all of these code cells are run in order, a parent directory named 'shoes' should be created in the current working directory. This parent directory should contain two subdirectories named 'athletic' and 'boots', holding the images belonging to those two classes.
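
As a quick sanity check that the scrape worked, we can count the files in each directory. This snippet is a minimal sketch rather than part of the original code; it assumes the directories created above exist.

#  sanity check: count scraped images per class
for folder in ['shoes/athletic', 'shoes/boots']:
    print(f'{folder}: {len(os.listdir(folder))} images')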

Loading & Labeling Images

Now that we have our custom dataset in place, we need to produce array representations of its constituent images (loading), label the arrays, and then convert them to tensors for use in PyTorch. Achieving this will require us to define a class to handle all of these processes. The class defined below does the first two steps: it reads images as grayscale, resizes them to 100 x 100 pixels, then labels them as desired (athletic shoes = [1, 0], boots = [0, 1]). NOTE: my working directory is the root directory, so I have defined the filepaths in the Python class below accordingly; you should define filepaths based on your own working directory.

#  defining class to load and label data
class LoadShoeData():
    """
    This class loads in data from each directory in numpy array format then saves
    loaded dataset
    """
    def __init__(self):
        self.athletic = 'shoes/athletic'
        self.boots = 'shoes/boots'
        self.labels = {self.athletic: np.eye(2, 2)[0], self.boots: np.eye(2, 2)[1]}
        self.img_size = 100
        self.dataset = []
        self.athletic_count = 0
        self.boots_count = 0

    def create_dataset(self):
        """
        This method reads images as grayscale from directories,
        resizes them and labels them as required.
        """

        #  reading from directory
        for key in self.labels:
          print(key)

          #  looping through all files in the directory
          for img_file in tqdm(os.listdir(key)):
            try:
              #  deriving image path
              path = os.path.join(key, img_file)

              #  reading image
              image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
              image = cv2.resize(image, (self.img_size, self.img_size))

              #  appending image and class label to list
              self.dataset.append([image, self.labels[key]])

              #  incrementing counter
              if key == self.athletic:
                self.athletic_count+=1
              elif key == self.boots:
                self.boots_count+=1

            #  skip files that fail to load (e.g. corrupt or non-image files)
            except Exception:
              pass

        #  shuffling array of images
        np.random.shuffle(self.dataset)

        #  printing to screen
        print(f'\nathletic shoe images: {self.athletic_count}')
        print(f'boot images: {self.boots_count}')
        print(f'total: {self.athletic_count + self.boots_count}')
        print('All done!')
        return np.array(self.dataset, dtype='object')

#  load data
data = LoadShoeData()

dataset = data.create_dataset()

Running the code cells above should return a NumPy array containing all the images in the custom dataset. Each element of that array is itself an array holding an image and its label.
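
As a further check, you could inspect one element of the returned array. The snippet below is a minimal sketch, not from the original article; it assumes matplotlib is installed for visualization.

#  sanity check: inspect the first image/label pair
import matplotlib.pyplot as plt

image, label = dataset[0]
print(image.shape)  # (100, 100)
print(label)        # [1. 0.] for athletic shoes, [0. 1.] for boots

#  visualizing the grayscale image
plt.imshow(image, cmap='gray')
plt.title(f'label: {label}')
plt.show()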

Creating a PyTorch Dataset

Having produced an array representation of all images and labels in the custom dataset, it is time to create a PyTorch dataset. To do this, we need to define a class which inherits from the PyTorch Dataset class, as seen below.

#  extending Dataset class
class ShoeDataset(Dataset):
    def __init__(self, custom_dataset, transforms=None):
        self.custom_dataset = custom_dataset
        self.transforms = transforms

    def __len__(self):
        return len(self.custom_dataset)
    
    def __getitem__(self, idx):
        #  extracting image at the given index
        image = self.custom_dataset[idx][0]
        #  extracting label from index
        label = torch.tensor(self.custom_dataset[idx][1])
        #  applying transforms if transforms are supplied
        if self.transforms:
          image = self.transforms(image)
        return (image, label)

Basically, two important methods are defined: __len__() and __getitem__(). The __len__() method returns the length of the custom dataset, while the __getitem__() method grabs an image and its label from the custom dataset via indexing, applies transforms (if any), and returns a tuple which PyTorch can then use.

#  creating an instance of the dataset class
dataset = ShoeDataset(dataset, transforms=transforms.ToTensor())

When the code cell above is run, the dataset object becomes a PyTorch dataset which can now be used in building a deep learning model.
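
From here, the dataset can be wrapped in a PyTorch DataLoader to serve shuffled mini-batches during training. Below is a minimal sketch; the batch size of 32 is an arbitrary choice of mine, not something prescribed by the article.

#  batching the PyTorch dataset with a DataLoader
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

#  grabbing one batch to confirm shapes
images, labels = next(iter(dataloader))
print(images.shape)  # torch.Size([32, 1, 100, 100]) - ToTensor adds a channel dimension
print(labels.shape)  # torch.Size([32, 2])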

Final Remarks

In this article, we took a look at working with custom datasets in PyTorch. We curated a custom dataset via web scraping, loaded and labeled it, and created a PyTorch dataset from it.

Knowledge of Python classes is brought to bear in this article. Most of the processes defined as classes could just as well be done with regular functions (except for the PyTorch Dataset class), but as a personal preference I chose to do it as I have done. Feel free to replicate this code in whatever way works best for you at this point in your programming journey.
