Framed Butterfly Displays (source: Ryan Somma on Flickr)

Image classification can perform some pretty amazing feats, but a large drawback of many image classification applications is that the model can assign only one class per image. With an object detection model, you can not only identify multiple objects of different classes in one image but also locate each one exactly with a bounding box that frames it.

The TensorFlow Models GitHub repository hosts a wide variety of pre-trained models for machine learning tasks, and one excellent resource in it is the Object Detection API. The API makes it easy to train your own object detection model for many different applications. Whether you need a fast model for live video streams at a high frame rate (frames per second, or fps) or a high-accuracy model for desktop use, the API makes it easy to train and export a model.

This tutorial walks through all the steps for building a custom object detection model using TensorFlow's API.

Gathering a data set

Some very large detection data sets, such as Pascal VOC and COCO, already exist, but if you want to detect a custom class, you have to create and label your own data set.

For my data set, I decided to collect images of chess pieces from internet image searches. I started with only images of white and black pawns, but I hope to include all the chess pieces in the future. Because I gathered all my images from search engines, I saved their links in one text file per class so they can be downloaded later with a script that uses scikit-image. Ideally, you want at least 100-300 training images; unfortunately, I could only find about 75 per class for the chess pieces. We'll see how the model does at the end of this post. Because my data was so limited, I set aside only 15% of the images for testing; ideally, you would reserve about 30% of your data for testing. For convenience, I resized all my images to 300 x 300 pixels before saving them, so I could draw the bounding boxes on the final images and not worry about resizing them later.
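Resizing before labeling matters because LabelImg records box corners in pixels of the image you label. If you labeled the original images and resized them afterward, every coordinate would have to be scaled by the same factors as the image. A minimal sketch of that scaling, assuming boxes stored as (xmin, ymin, xmax, ymax) pixel tuples; `scale_box` is a hypothetical helper, not part of the API:

```python
def scale_box(box, orig_size, new_size=(300, 300)):
    """Scale an (xmin, ymin, xmax, ymax) pixel box from orig_size to new_size.

    Sizes are (width, height). Hypothetical helper for illustration only.
    """
    orig_w, orig_h = orig_size
    new_w, new_h = new_size
    sx = new_w / orig_w  # horizontal scale factor
    sy = new_h / orig_h  # vertical scale factor
    xmin, ymin, xmax, ymax = box
    return (round(xmin * sx), round(ymin * sy),
            round(xmax * sx), round(ymax * sy))


# A pawn labeled at (120, 40)-(360, 520) in a 600 x 800 image
print(scale_box((120, 40, 360, 520), (600, 800)))  # (60, 15, 180, 195)
```

The x and y factors differ whenever the original image is not square, because a plain resize to 300 x 300 stretches the image.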

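# This script downloads and resizes every image listed in the imageLinks text
# files (one file per class) and splits the images into train and test folders,
# copying each image's XML label alongside it.

#Editor's note: It is your responsibility to ensure that use of copyrighted images accessed in connection with this script complies with any license restrictions that may apply.

import os
import shutil
import urllib.request

import numpy as np
import skimage.io
import skimage.transform
from skimage.util import img_as_ubyte

# Assumed folder layout; adjust these paths to your project
linksPath = 'imageLinks'   # e.g. whitePawn.txt, blackPawn.txt
labelsPath = 'labels'      # LabelImg XML files, e.g. whitePawn0.xml
trainPath = 'data/train'
testPath = 'data/test'

copyLabels = True
trainPercent = 0.85

def check(link):
    # Assumed helper: True if the link can be opened, False otherwise
    try:
        with urllib.request.urlopen(link, timeout=10):
            return True
    except Exception:
        return False

def save(link, name, outPath):
    image = skimage.io.imread(link)
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]  # JPEG cannot store an alpha channel
    image = skimage.transform.resize(image, (300, 300))  # floats in [0, 1]
    skimage.io.imsave(os.path.join(outPath, name + '.jpg'), img_as_ubyte(image))
    if copyLabels:
        shutil.copyfile(os.path.join(labelsPath, name + '.xml'),
                        os.path.join(outPath, name + '.xml'))

for path in (trainPath, testPath):
    os.makedirs(path, exist_ok=True)

for classes in os.listdir(linksPath):
    with open(os.path.join(linksPath, classes)) as text:
        links = [line.strip() for line in text]
    cut = int(np.floor(len(links) * trainPercent))
    className = classes[:-4]  # strip the '.txt' extension
    for i, link in enumerate(links):
        outPath = trainPath if i < cut else testPath
        if check(link):
            save(link, className + str(i), outPath)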
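The script assigns the first 85% of each link file to training, so the split depends on the order in which the links were collected (for example, search-ranking order). A minimal sketch of an alternative that shuffles with a fixed seed before splitting, so the split is random but reproducible. It shuffles (index, link) pairs because each image and its label file are named after the link's position in the class's text file. `split_links` and the seed value are illustrative assumptions:

```python
import math
import random


def split_links(links, train_percent=0.85, seed=42):
    """Shuffle (index, link) pairs reproducibly and split them into train/test.

    Indices are kept so each image still matches the label file named
    after its original position in the link file.
    """
    pairs = list(enumerate(links))
    random.Random(seed).shuffle(pairs)  # same order on every run
    cut = math.floor(len(pairs) * train_percent)
    return pairs[:cut], pairs[cut:]


links = [f"https://example.com/pawn{i}.jpg" for i in range(75)]
train, test = split_links(links)
print(len(train), len(test))  # 63 12
```

In the download loop, you would then iterate over `for i, link in train:` and `for i, link in test:` instead of slicing at `cut`.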

Creating bounding boxes

To train our object detection model, we need each image's width and height, plus the class and bounding box (xmin, ymin, xmax, ymax) of every object in it. Simply put, a bounding box is the frame that captures exactly where an object of a given class sits in the image.
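A minimal sketch of one labeled object as a data structure, assuming pixel coordinates with the origin at the top-left corner of the image. The `BoundingBox` class and the `whitePawn` class name are hypothetical; the class checks that a box lies inside its image and shows the normalized (0 to 1) coordinates that the API's TFRecord input expects:

```python
from dataclasses import dataclass


@dataclass
class BoundingBox:
    label: str
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def validate(self, width, height):
        """Raise ValueError if the box is empty or falls outside the image."""
        if not (0 <= self.xmin < self.xmax <= width):
            raise ValueError(f"bad x range for {self.label}")
        if not (0 <= self.ymin < self.ymax <= height):
            raise ValueError(f"bad y range for {self.label}")

    def normalized(self, width, height):
        """Return (xmin, ymin, xmax, ymax) as fractions of the image size."""
        return (self.xmin / width, self.ymin / height,
                self.xmax / width, self.ymax / height)


box = BoundingBox("whitePawn", 60, 15, 180, 195)
box.validate(300, 300)
print(box.normalized(300, 300))  # (0.2, 0.05, 0.6, 0.65)
```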

Figure 1. Bounding box. Source: Pixabay, released under Creative Commons CC0.

Creating these labels by hand can be a huge ordeal, but thankfully there are programs that help create bounding boxes. LabelImg is excellent free, open source software that makes the labeling process much easier. It saves an individual XML label file for each image, which we will convert into a CSV table for training. The labels for all the images used in the pawn detector we are building are included in the GitHub repository.

Install the object detection API

Before getting started, we have to clone and install the Object Detection API in our project folder. Installing the API is simple: you just need to clone the TensorFlow Models repository, compile its protobuf files, and add the research and slim directories to your Python path. The full installation process for Docker or native Python is described in the GitHub repository's README.

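# Install this project's Python dependencies
pip3 install -r requirements.txt
# The protobuf compiler builds the API's .proto definitions
apt-get install -y protobuf-compiler
git clone https://github.com/tensorflow/models.git
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
# Make object_detection and slim importable (repeat in every new shell)
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim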
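Because the `export` line only affects the current shell session, it is worth confirming that Python can actually find the packages before moving on. A minimal sketch using the standard library's `importlib.util.find_spec`, which locates a module without importing it. The package names `object_detection` (under `models/research`) and `nets` (under `slim`) reflect the repository layout at the time of writing and are an assumption worth double-checking:

```python
import importlib.util


def api_on_path(packages=("object_detection", "nets")):
    """Map each package name to True if Python can locate it, else False."""
    return {name: importlib.util.find_spec(name) is not None
            for name in packages}


if __name__ == "__main__":
    for name, found in api_on_path().items():
        print(f"{name}: {'found' if found else 'NOT found; check PYTHONPATH'}")
```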

Convert labels to the TFRecord format

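When training models with TensorFlow, storing your data as TFRecord files helps optimize your data feed. We can generate TFRecord files using code adapted from this raccoon detector; the first step converts the XML labels into one CSV table each for the training and test sets.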

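# Modified from:
# https://github.com/datitran/raccoon_dataset/blob/master/xml_to_csv.py

import glob
import os
import xml.etree.ElementTree as ET

import pandas as pd

# root is the project directory; trainPath and testPath are the train/test
# folders filled by the download script (adjust to your layout).
os.chdir(root)

def xml_to_csv(path):
    """Collect every bounding box in the LabelImg XML files in `path`."""
    xml_list = []
    for xml_file in glob.glob(os.path.join(path, '*.xml')):
        annotation = ET.parse(xml_file).getroot()
        size = annotation.find('size')
        # Look fields up by tag name instead of relying on element order
        for member in annotation.findall('object'):
            box = member.find('bndbox')
            value = (annotation.find('filename').text,
                     int(size.find('width').text),
                     int(size.find('height').text),
                     member.find('name').text,
                     int(box.find('xmin').text),
                     int(box.find('ymin').text),
                     int(box.find('xmax').text),
                     int(box.find('ymax').text))
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)

def main():
    os.makedirs('data', exist_ok=True)
    for image_path in [trainPath, testPath]:
        folder = os.path.basename(os.path.normpath(image_path))
        xml_df = xml_to_csv(image_path)
        xml_df.to_csv(os.path.join('data', folder + '.csv'), index=None)
    print('Successfully converted xml to csv.')

main()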
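The CSV tables still have to become TFRecord files, and the raccoon detector's repository includes a script for that conversion. The API reads each image as a `tf.train.Example` whose features include the image size, the box corners as fractions of the image size, and integer class ids. A minimal sketch of the grouping and normalization step in plain Python, assuming the column names produced above and a hypothetical `class_ids` mapping (ids start at 1 because the API reserves 0 for the background). Wrapping the values in `tf.train.Feature` objects, adding the encoded image bytes, and writing them with `tf.io.TFRecordWriter` (`tf.python_io.TFRecordWriter` in older TensorFlow 1.x releases) is left out:

```python
import csv
import io
from collections import defaultdict


def csv_to_features(csv_file, class_ids):
    """Group CSV rows by filename and build one feature dict per image."""
    rows_by_file = defaultdict(list)
    for row in csv.DictReader(csv_file):
        rows_by_file[row["filename"]].append(row)

    examples = []
    for filename, rows in rows_by_file.items():
        width = int(rows[0]["width"])
        height = int(rows[0]["height"])
        examples.append({
            "image/filename": filename,
            "image/width": width,
            "image/height": height,
            # Box corners are stored as fractions of the image size
            "image/object/bbox/xmin": [int(r["xmin"]) / width for r in rows],
            "image/object/bbox/xmax": [int(r["xmax"]) / width for r in rows],
            "image/object/bbox/ymin": [int(r["ymin"]) / height for r in rows],
            "image/object/bbox/ymax": [int(r["ymax"]) / height for r in rows],
            "image/object/class/text": [r["class"] for r in rows],
            "image/object/class/label": [class_ids[r["class"]] for r in rows],
        })
    return examples


# Hypothetical class names; use the names you typed into LabelImg
class_ids = {"whitePawn": 1, "blackPawn": 2}
sample = io.StringIO(
    "filename,width,height,class,xmin,ymin,xmax,ymax\n"
    "board1.jpg,300,300,whitePawn,60,15,180,195\n"
    "board1.jpg,300,300,blackPawn,150,150,210,270\n"
)
features = csv_to_features(sample, class_ids)
print(len(features), features[0]["image/object/class/label"])  # 1 [1, 2]
```

Grouping by filename matters because one image can hold several pawns, and all of its boxes must go into a single example.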

Choose a model

The API offers several models to suit different needs. If you want a high-speed model that can detect objects in a video feed at high fps, the Single Shot MultiBox Detector (SSD) network works best. Some other object detection approaches slide boxes of different sizes across the image and run the classifier many times on different sections of it, which can be very resource-intensive. As its name suggests, SSD predicts all bounding boxes and their class probabilities in a single pass, which makes it much faster. However, with SSD you gain speed at the cost of accuracy. I'll use SSD as the detection framework, and for the neural network architecture I'll use MobileNet, which is designed for mobile applications. I've already set up the config file for SSD MobileNet and included it in the GitHub repository for this post. Depending on your computer, you may have to lower the batch size in the config file if you run out of memory.
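The batch size is set in the `train_config` block of the pipeline config, as a line of the form `batch_size: N`. A minimal sketch that lowers it programmatically with a regular expression; `set_batch_size` is a hypothetical helper, and editing the file by hand works just as well:

```python
import re


def set_batch_size(config_text, new_size):
    """Return config_text with every `batch_size: N` field set to new_size."""
    return re.sub(r"(batch_size:\s*)\d+",
                  lambda match: match.group(1) + str(new_size),
                  config_text)


# Demo on a fragment shaped like a train_config block
sample = "train_config: {\n  batch_size: 24\n}\n"
print(set_batch_size(sample, 8))
```

To apply it, read the config file's text, pass it through `set_batch_size`, and write the result back. A smaller batch needs less GPU memory per step, at the cost of noisier gradient updates.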

Retrain the model with your data

You could train the entire SSD MobileNet model on your own data from scratch, but that would require thousands of training images, multiple GPUs, and roughly a week of training time. The much easier solution is transfer learning: take a model already trained on a large data set, clip off the last layer, which holds the classes from the trained model, and replace it with your own classes. This way, you reuse all the feature detectors learned by the previous model and use those features to detect your new classes. Because training starts from these pre-trained weights rather than from scratch, a high-end GPU is not required (though it can certainly speed things up). Once the loss stays consistently around 1 or starts rising, you can stop training by pressing Ctrl+C. To train, simply run the `train.py` file in the Object Detection API directory.
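The stopping rule above is a judgment call made while watching the training log. A minimal sketch that encodes it, assuming you collect recent loss values into a list; `should_stop`, the window size, and the tolerance around 1 are illustrative assumptions, not settings from the API:

```python
from statistics import mean


def should_stop(losses, window=20, target=1.0, tolerance=0.2):
    """Suggest stopping when recent losses stay near `target` or trend upward."""
    if len(losses) < 2 * window:
        return False  # not enough history to judge
    recent = losses[-window:]
    previous = losses[-2 * window:-window]
    # "Consistently around 1": every recent value within the tolerance band
    settled = all(abs(loss - target) <= tolerance for loss in recent)
    # "Starts rising": the recent window averages higher than the one before
    rising = mean(recent) > mean(previous)
    return settled or rising


print(should_stop([5.0 - 0.05 * i for i in range(40)]))  # False
print(should_stop([1.1] * 40))                           # True
```

Comparing window averages rather than single steps keeps the normal step-to-step noise in the loss from triggering a false "rising" signal.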

python3 models/research/object_detection/train.py --logtostderr --train_dir=data/ --pipeline_config_path=data/ssd_mobilenet_v1_pets.config

Implement new model with TensorFlow

Before we start experimenting with our newly trained model, we have to export our graph for inference. For `--trained_checkpoint_prefix`, use the number of the latest checkpoint saved in your data directory (997 in the example below).

python3 models/research/object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path data/ssd_mobilenet_v1_pets.config \
    --trained_checkpoint_prefix data/model.ckpt-997 \
    --output_directory object_detection_graph
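Rather than reading the directory listing by eye, you can find the newest checkpoint programmatically. TensorFlow provides `tf.train.latest_checkpoint` for this, which reads the `checkpoint` index file; the standard-library sketch below instead scans for `model.ckpt-N.index` files, assuming the TensorFlow 1.x naming used in the command above. `latest_checkpoint_prefix` is a hypothetical helper:

```python
import re
from pathlib import Path


def latest_checkpoint_prefix(train_dir):
    """Return the prefix (e.g. 'data/model.ckpt-997') of the newest checkpoint."""
    best_step, best_prefix = -1, None
    for index_file in Path(train_dir).glob("model.ckpt-*.index"):
        match = re.fullmatch(r"model\.ckpt-(\d+)\.index", index_file.name)
        # Compare step numbers as integers so 1000 beats 997
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = str(index_file.with_suffix(""))  # drop '.index'
    return best_prefix


if __name__ == "__main__":
    print(latest_checkpoint_prefix("data"))
```

Comparing numbers rather than filenames matters: sorted as strings, `model.ckpt-997` would come after `model.ckpt-1000`.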

Now we can start having some fun, using code modified from TensorFlow's Object Detection Notebook. To validate our model's performance, we will use images it has never seen before; for my validation images, I used pictures of my own chess set. If you are building a serious detection model, I highly recommend setting aside about 10% of your total images as validation images. With a good number of validation images, you can test multiple checkpoints to see which one performs best.
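The notebook code runs the exported graph and draws each detection whose score clears a threshold. The graph returns boxes as normalized [ymin, xmin, ymax, xmax] values alongside parallel arrays of scores and class ids, so turning them into pixel boxes is a small step. A minimal sketch in plain Python; `filter_detections`, the 0.5 threshold, and the `category_names` mapping are assumptions for illustration, and the `int(...)` cast covers class ids that arrive as floats:

```python
def filter_detections(boxes, scores, classes, image_size,
                      category_names, min_score=0.5):
    """Keep detections scoring at least min_score, converted to pixel boxes.

    boxes are [ymin, xmin, ymax, xmax] fractions; image_size is (width, height).
    Returns (name, score, (xmin, ymin, xmax, ymax)) tuples, best first.
    """
    width, height = image_size
    kept = []
    for (ymin, xmin, ymax, xmax), score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        pixel_box = (round(xmin * width), round(ymin * height),
                     round(xmax * width), round(ymax * height))
        kept.append((category_names[int(cls)], score, pixel_box))
    return sorted(kept, key=lambda d: d[1], reverse=True)


category_names = {1: "whitePawn", 2: "blackPawn"}
boxes = [[0.05, 0.2, 0.65, 0.6], [0.5, 0.5, 0.9, 0.7], [0.1, 0.1, 0.2, 0.2]]
scores = [0.92, 0.71, 0.12]
classes = [1, 2, 1]
for detection in filter_detections(boxes, scores, classes, (300, 300),
                                   category_names):
    print(detection)
```

Running the same filter over each checkpoint's output on the validation images is one simple way to compare checkpoints side by side.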

Figure 2. Pictures of my chess set. Image courtesy of Justin Francis.
Figure 3. Pictures of my chess set. Image courtesy of Justin Francis.

It was clear that my limited data, about 100 images per class, was not enough to build a robust model. Although the detector found pawns in direct front shots, it missed pawns that were blurry, distant, at an angle, or partially covered. Still, I believe this toy example showcases the API's capabilities well. My goal was to gather all the steps for creating a custom object detection model in one place, and I highly recommend experimenting with all the models. Taking this tutorial a step further, you could run the frozen model on a mobile device using TensorFlow's Android Camera Demo. I really hope you use the tools provided to create your own custom object detection model.
