Chapter 4. Neural Network Development Reference Designs

In the previous chapter we covered the NN development process at a high level, and you learned how to implement each stage in PyTorch. The examples in that chapter focused on solving an image classification problem with the CIFAR-10 dataset and a simple fully connected network. CIFAR-10 image classification is a good academic example to illustrate the NN development process, but there’s a lot more to developing deep learning models with PyTorch.

This chapter presents some additional reference designs for NN development with PyTorch. Reference designs are code examples that you can use as a reference to solve similar types of problems.

Indeed, the set of reference designs in this chapter merely scratches the surface when it comes to the possibilities of deep learning; however, I’ll attempt to provide you with enough variety to assist you in the development of your own solutions. We will use three examples to process a variety of data, design different model architectures, and explore other approaches to the learning process.

The first example uses PyTorch to perform transfer learning to classify images of bees and ants with a small dataset and a pretrained network. The second example uses PyTorch to perform sentiment analysis using text data to train an NLP model that predicts the positive or negative sentiment of movie reviews. And the third example uses PyTorch to demonstrate generative learning by training a generative adversarial network (GAN) to generate images of articles of clothing.

In each example, I’ll provide PyTorch code so that you can use this chapter as a quick reference when writing code for your own designs. Let’s begin by seeing how PyTorch can solve a computer vision problem using transfer learning.

Image Classification with Transfer Learning

The subject of image classification has been studied in depth, and many famous models, like the AlexNet and VGG models we saw earlier, are readily available through PyTorch. However, these models have been trained with the ImageNet dataset. Although ImageNet contains 1,000 different image classes, it may not contain the classes that you need to solve your image classification problem.

In this case, you can apply transfer learning, a process in which we fine-tune a pretrained model with a much smaller dataset of new images. For our next example, we will train a model to classify images of bees and ants. Bees and ants look very similar and can be difficult to distinguish.

To train our new classifier, we will fine-tune another famous model, called ResNet18, by loading the pretrained model and training it with roughly 120 new training images each of bees and ants, a much smaller set than the millions of images in ImageNet.

Data Processing

Let’s begin by loading our data, defining our transforms, and configuring our dataloaders for batch sampling. As we did earlier, we’ll leverage functions from the Torchvision library for creating the datasets, loading the data, and applying the data transforms.

First let’s import the required libraries for this example:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models
from torchvision import transforms

Then we’ll download the data that we’ll use for training and validation:

from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile

zipurl = 'https://pytorch.tips/bee-zip'
with urlopen(zipurl) as zipresp:
    with ZipFile(BytesIO(zipresp.read())) as zfile:
        zfile.extractall('./data')

Here, we use the io, urllib, and zipfile libraries to download and unzip a file to our local filesystem. After running the previous code, you should have your training and validation images in your local data/ folder. They are located in data/hymenoptera_data/train and data/hymenoptera_data/val, respectively.

Next let’s define our transforms, load the data, and configure our batch samplers.

First we’ll define our transforms:

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225])])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225])])

Notice that we randomly crop, resize, and flip images for training but not for validation. The “magic” numbers used in the Normalize transforms are the precomputed per-channel (RGB) means and standard deviations of the ImageNet dataset, which is the data the pretrained model was trained on.
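To make the Normalize arithmetic concrete, here is a minimal pure-Python sketch of what happens to a single RGB pixel. The helper names normalize_pixel() and denormalize_pixel() are hypothetical and exist only for illustration; Torchvision applies the same formula to whole tensors.

```python
# Per-channel statistics used in the Normalize transforms above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

def denormalize_pixel(rgb):
    """Invert the normalization: value * std + mean."""
    return [v * s + m for v, m, s in zip(rgb, MEAN, STD)]

pixel = [0.5, 0.5, 0.5]  # a mid-gray pixel, as produced by ToTensor()
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)
print(normalized)
print(restored)
```

The inverse step is the same operation we will need later when we want to display normalized images.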

Now let’s define the datasets:

train_dataset = datasets.ImageFolder(
            root='data/hymenoptera_data/train',
            transform=train_transforms)

val_dataset = datasets.ImageFolder(
            root='data/hymenoptera_data/val',
            transform=val_transforms)

In the previous code we used the ImageFolder dataset to pull images from our data folders and set the transforms to the ones we defined earlier. Next, we define our dataloaders for batch iteration:

train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=4,
            shuffle=True,
            num_workers=4)

val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=4,
            shuffle=True,
            num_workers=4)

We’re using a batch size of 4, and we set num_workers to 4 to configure four CPU processes to handle the parallel processing.

Now that we have prepared our training and validation data, we can design our model.

Model Design

For this example we’ll use a ResNet18 model that has been pretrained with ImageNet data. However, ResNet18 is designed to detect 1,000 classes, and in our case, we only need 2 classes—bees and ants. We can modify the final layer to detect 2 classes instead of 1,000 as shown in the following code:

model = models.resnet18(pretrained=True)

print(model.fc)
# out:
# Linear(in_features=512, out_features=1000, bias=True)

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
print(model.fc)
# out:
# Linear(in_features=512, out_features=2, bias=True)

We first load a pretrained ResNet18 model using the function torchvision.models.resnet18(). Next, we read the number of features before the final layer with model.fc.in_features. Then we change the final layer by directly setting model.fc to a fully connected layer with two outputs.

We are going to use the pretrained model as a starting point and fine-tune its parameters with new data. Since we replaced the final linear layer, its parameters are now randomly initialized.

Now we have a ResNet18 model with all weights pretrained with ImageNet images except for the last layer. Next, we need to train our model with images of bees and ants.
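The example in this section fine-tunes every layer. A common variation, sketched here as an assumption rather than something this example does, is to freeze the pretrained layers and train only the new final layer. The tiny backbone and head below are hypothetical stand-ins so that the sketch runs without downloading any weights.

```python
import torch.nn as nn

# Hypothetical stand-in for a pretrained network; a real ResNet18 is much larger
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 2)  # new, randomly initialized two-class layer
toy_model = nn.Sequential(backbone, head)

# Freeze the backbone so that only the new head receives gradient updates
for param in backbone.parameters():
    param.requires_grad = False

trainable = [name for name, p in toy_model.named_parameters()
             if p.requires_grad]
print(trainable)
```

With a real ResNet18 you would apply the same loop to the model's parameters before replacing model.fc, so that only the new layer stays trainable.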

Tip

Torchvision provides many famous pretrained models for computer vision and image processing, including the following:

  • AlexNet

  • VGG

  • ResNet

  • SqueezeNet

  • DenseNet

  • Inception v3

  • GoogLeNet

  • ShuffleNet v2

  • MobileNet v2

  • ResNeXt

  • Wide ResNet

  • MNASNet

For more information, explore the torchvision.models module or visit the Torchvision models documentation.

Training and Validation

Before we fine-tune our model, let’s configure our training with the following code:

from torch.optim.lr_scheduler import StepLR

device = torch.device("cuda:0" if
  torch.cuda.is_available() else "cpu")  # 1

model = model.to(device)
criterion = nn.CrossEntropyLoss()  # 2
optimizer = optim.SGD(model.parameters(),
                      lr=0.001,
                      momentum=0.9)  # 3
exp_lr_scheduler = StepLR(optimizer,
                          step_size=7,
                          gamma=0.1)  # 4
1

Move the model to a GPU if available.

2

Define our loss function.

3

Define our optimizer algorithm.

4

Use a learning rate scheduler.

The code should look familiar, with the exception of the learning rate scheduler. Here we will use a scheduler from PyTorch to adjust the learning rate of our SGD optimizer after several epochs. Using a learning rate scheduler will help our NN adjust its weights more precisely as training goes on.
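With step_size=7 and gamma=0.1, the scheduler multiplies the learning rate by 0.1 every seven epochs. The following sketch reproduces that arithmetic in plain Python; step_lr() is a hypothetical helper written for illustration, not part of PyTorch.

```python
def step_lr(initial_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate in effect after `epoch` scheduler steps of a step decay."""
    return initial_lr * gamma ** (epoch // step_size)

# Learning rate at a few points of a 25-epoch run starting from 0.001
for epoch in (0, 6, 7, 14, 21):
    print(epoch, step_lr(0.001, epoch))
```

Over 25 epochs the learning rate therefore drops three times, which is why later epochs make smaller weight adjustments.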

The following code illustrates the entire training loop, including validation:

num_epochs = 25

for epoch in range(num_epochs):

  model.train()  # 1
  running_loss = 0.0
  running_corrects = 0

  for inputs, labels in train_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    loss = criterion(outputs, labels)

    loss.backward()
    optimizer.step()

    # criterion returns the batch mean, so weight it
    # by the batch size before accumulating
    running_loss += loss.item() * inputs.size(0)
    running_corrects += \
      torch.sum(preds == labels).item()

  exp_lr_scheduler.step()  # 2
  train_epoch_loss = \
    running_loss / len(train_dataset)
  train_epoch_acc = \
    running_corrects / len(train_dataset)

  model.eval()  # 3
  running_loss = 0.0
  running_corrects = 0

  # gradients are not needed for validation
  with torch.no_grad():
    for inputs, labels in val_loader:
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)
      loss = criterion(outputs, labels)

      running_loss += loss.item() * inputs.size(0)
      running_corrects += \
        torch.sum(preds == labels).item()

  epoch_loss = running_loss / len(val_dataset)
  epoch_acc = running_corrects / len(val_dataset)
  print("Train: Loss: {:.4f} Acc: {:.4f}"
    " Val: Loss: {:.4f}"
    " Acc: {:.4f}".format(train_epoch_loss,
                          train_epoch_acc,
                          epoch_loss,
                          epoch_acc))
1

Training loop.

2

Schedule the learning rate for the next epoch of training.

3

Validation loop.

We should see the training and validation loss decrease while the accuracies improve. The results may bounce around a little.

Testing and Deployment

Let’s test our model and deploy it by saving the model to a file. To test our model, we’ll display a batch of images and show how our model classified them, as shown in the following code:

import matplotlib.pyplot as plt

def imshow(inp, title=None):  # 1
    inp = inp.numpy().transpose((1, 2, 0))  # 2
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # 3
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)

inputs, classes = next(iter(val_loader))  # 4
out = torchvision.utils.make_grid(inputs)
class_names = val_dataset.classes

outputs = model(inputs.to(device))  # 5
_, preds = torch.max(outputs, 1)  # 6

imshow(out, title=[class_names[x] for x in preds])  # 7
1

Define a new function to plot images from our tensor images.

2

Switch from C × H × W to H × W × C image formats for plotting.

3

Undo the normalization we do during transforms so we can properly view images.

4

Grab a batch of images from our validation dataset.

5

Perform classification using our fine-tuned ResNet18.

6

Take the “winning” class.

7

Display the input images and their predicted classes.

Since we have such a small dataset, we simply test the model by visualizing the output to make sure the images match the labels. Figure 4-1 shows an example test. Your results will vary since the val_loader will return a randomly sampled batch of images.

Figure 4-1. Results of image classification

When we are done, we save the model:

torch.save(model.state_dict(), "./resnet18.pt")
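Because only the state dictionary is saved, loading the model later means rebuilding the same architecture first and then copying the saved parameters into it. The sketch below shows that round trip with a small stand-in model; build_model() and the temporary file path are hypothetical, and for the example in this section you would instead rebuild ResNet18 and replace its final layer with a two-output layer before loading.

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model():
    # Hypothetical stand-in for "rebuild the architecture you trained"
    return nn.Linear(4, 2)

trained = build_model()
path = os.path.join(tempfile.mkdtemp(), "toy.pt")
torch.save(trained.state_dict(), path)

restored = build_model()                    # same architecture, fresh weights
restored.load_state_dict(torch.load(path))  # copy the saved parameters in
restored.eval()                             # switch to inference mode

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```

If the rebuilt architecture does not match the saved one, load_state_dict() reports the missing or unexpected keys instead of loading silently.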

You can use this reference design for other cases of transfer learning, not only with image classification but with other types of data as well. As long as you can find a suitable pretrained model, you will be able to modify the model and retrain only a portion of it with a small amount of data.

This example was based on the "Transfer Learning for Computer Vision Tutorial" by Sasank Chilamkurthy. You can find more details in the tutorial.

Next, we’ll venture into the field of NLP and explore a reference design that processes text data.

Sentiment Analysis with Torchtext

Another popular deep learning application is sentiment analysis, in which a model classifies a block of text according to the opinion it expresses. In this example, we will train an NN to predict whether a movie review is positive or negative using the well-known Internet Movie Database (IMDb) dataset. Sentiment analysis of IMDb data is a common beginner example for learning NLP.

Data Processing

The IMDb dataset consists of 50,000 movie reviews from IMDb that are labeled by sentiment (positive or negative) and split evenly into 25,000 reviews for training and 25,000 for testing. The PyTorch project includes a library called Torchtext that provides convenient capabilities for performing deep learning on text data. To begin our example reference design, we will use Torchtext to load and preprocess the IMDb dataset.

Before we load the dataset, we will define a function called generate_bigrams() that we’ll use to preprocess our text review data. The model that we’ll use for this example computes the n-grams of an input sentence and appends them to the end of its token list. We’ll use bi-grams, which are pairs of adjacent words or tokens in a sentence.

The following code shows our preprocessing function, generate_bigrams(), and provides an example of how it works:

def generate_bigrams(x):
  n_grams = set(zip(*[x[i:] for i in range(2)]))
  for n_gram in n_grams:
    x.append(' '.join(n_gram))
  return x

generate_bigrams([
        'This', 'movie', 'is', 'awesome'])
# out:
# ['This', 'movie', 'is', 'awesome', 'This movie',
#  'movie is', 'is awesome']
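Note that generate_bigrams() collects the pairs in a set, so the order of the appended bi-grams can vary, and it modifies the list passed to it. If you need reproducible output, a variant along the following lines may help. This is a sketch under assumptions: bigrams_in_order() is a hypothetical name, and unlike the set-based version it keeps repeated bi-grams.

```python
def bigrams_in_order(tokens):
    """Return a new list: the tokens followed by their adjacent pairs, in order."""
    pairs = [' '.join(pair) for pair in zip(tokens, tokens[1:])]
    return list(tokens) + pairs

tokens = ['This', 'movie', 'is', 'awesome']
print(bigrams_in_order(tokens))
print(tokens)  # the input list is left unchanged
```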

Now that we have defined our preprocessing function, we can build our IMDb datasets as shown in the following code:

from torchtext.datasets import IMDB
from torch.utils.data.dataset import random_split

train_iter, test_iter = IMDB(
    split=('train', 'test'))  # 1

train_dataset = list(train_iter)  # 2
test_data = list(test_iter)

num_train = int(len(train_dataset) * 0.70)
train_data, valid_data = \
    random_split(train_dataset,
        [num_train,
         len(train_dataset) - num_train])  # 3
1

Load data from IMDb dataset.

2

Redefine iterators as lists.

3

Split training data into two sets, 70% for training and 30% for validation.

In the code, we load the training and test datasets using the IMDB class. We then use the random_split() function to break the training data into two smaller sets for training and validation.

Warning

The Torchtext API changed significantly in PyTorch 1.8. The code in this example was written for Torchtext 0.9; later releases changed the vocabulary API again, so you may need to adapt the code if you use a newer version.

Let’s take a quick look at the data:

print(len(train_data), len(valid_data),
  len(test_data))
# out: 17500 7500 25000

data_index = 21
print(train_data[data_index][0])
# out: (your results may vary)
#   pos

print(train_data[data_index][1])
# out: (your results may vary)
# ['This', 'film', 'moved', 'me', 'beyond', ...

As you can see, our datasets have 17,500 reviews for training, 7,500 for validation, and 25,000 for testing. We also printed out the sentiment and the text of the review at index 21, as shown in the output. The splits are randomly sampled, so your results may be different.

Next we need to convert our text data into numerical data so that an NN can process it. We do this by creating preprocessing functions and a data pipeline. The data pipeline will use our generate_bigrams() function, a tokenizer, and a vocabulary, as shown in the following code:

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('spacy')  # 1
counter = Counter()
for (label, line) in train_data:
    counter.update(generate_bigrams(
        tokenizer(line)))  # 2
vocab = Vocab(counter,
              max_size=25000,
              vectors="glove.6B.100d",
              unk_init=torch.Tensor.normal_)  # 3
1

Define our tokenizer (how to break up text).

2

Make a list of all the tokens used in our training data and count how many times each occurs.

3

Create a vocabulary (list of possible tokens) and define how tokens are converted to numbers.

In the code, we define the instructions for converting text to tensors. For the review text, we specify spaCy as the tokenizer. spaCy is a popular Python package for NLP and includes its own tokenizer. A tokenizer breaks text into components like words and punctuation marks.

We also create a vocabulary and an embedding. A vocabulary is just the set of tokens that our model recognizes, each mapped to an integer index. If we find a token in a movie review that is not contained in the vocabulary, we replace it with a special “unknown” token. We limit our vocabulary to the 25,000 most frequent tokens, much smaller than the full set of words in the English language.

We also specify our vocabulary vectors, which causes us to download a pretrained embedding called GloVe (Global Vectors for Word Representation) with 100 dimensions. It may take several minutes to download the GloVe data and create a vocabulary.

An embedding is a method to map a word or series of words to a numeric vector. Defining a vocabulary and an embedding is a complex topic and is beyond the scope of this book. For this example, we’ll just build a vocabulary from our training data and download the popular pretrained GloVe embedding.
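To illustrate the idea of a size-limited vocabulary with an unknown token, here is a minimal pure-Python sketch. It is not how Torchtext implements its vocabulary; build_vocab() and lookup() are hypothetical helpers, and reserving indices 0 and 1 for the unknown and padding tokens is an assumption made for this illustration.

```python
from collections import Counter

def build_vocab(token_lists, max_size):
    """Map the most frequent tokens to indices; 0 and 1 are reserved."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    itos = ['<unk>', '<pad>'] + [tok for tok, _ in counts.most_common(max_size)]
    return {tok: idx for idx, tok in enumerate(itos)}

def lookup(stoi, token):
    # Tokens outside the vocabulary fall back to the unknown index
    return stoi.get(token, stoi['<unk>'])

reviews = [['good', 'movie'], ['good', 'movie'], ['bad', 'movie']]
stoi = build_vocab(reviews, max_size=2)
print([lookup(stoi, t) for t in ['movie', 'good', 'bad', 'terrible']])
```

Here "bad" is dropped because only the two most frequent tokens are kept, so it maps to the same index as a word that never appeared at all.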

Now that we have defined our tokenizer and vocabulary, we can build our data pipelines for the review and label text data, as shown in the following code:

text_pipeline = lambda x: [vocab[token]
    for token in generate_bigrams(tokenizer(x))]

label_pipeline = lambda x: 1 if x=='pos' else 0

print(text_pipeline('the movie was horrible'))
# out:

print(label_pipeline('neg'))
# out:

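We use lambda functions to define the pipelines. The text pipeline tokenizes a review, appends its bi-grams, and converts each token to its integer index in the vocabulary; the label pipeline converts the label to 1 (positive) or 0 (negative). The model’s embedding layer will later map each index to a 100-element vector.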

Now that we have defined our datasets and preprocessing, we can create our dataloaders. Our dataloaders load batches of data from a sampling of the dataset and preprocess the data, as in the following code:

import random

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cuda" if
    torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(
                           text_pipeline(_text))
        text_list.append(processed_text)
    # float32 labels match the dtype of the model output;
    # shorter reviews are padded with index 1 (the pad token)
    return (torch.tensor(label_list,
          dtype=torch.float32).to(device),
          pad_sequence(text_list,
                       padding_value=1).to(device))

BATCH_SIZE = 64

# Optional alternative to shuffle=True (not used below):
# group reviews of similar length to reduce padding
def batch_sampler():
    indices = [(i, len(tokenizer(s[1])))
                for i, s in enumerate(train_data)]
    random.shuffle(indices)
    pooled_indices = []
    # create pool of indices with similar lengths
    for i in range(0, len(indices), BATCH_SIZE * 100):
        pooled_indices.extend(sorted(
          indices[i:i + BATCH_SIZE * 100], key=lambda x: x[1]))

    pooled_indices = [x[0] for x in pooled_indices]

    # yield indices for current batch
    for i in range(0, len(pooled_indices),
      BATCH_SIZE):
        yield pooled_indices[i:i + BATCH_SIZE]

train_dataloader = DataLoader(train_data,
                  batch_size=BATCH_SIZE,
                  shuffle=True,
                  collate_fn=collate_batch)
valid_dataloader = DataLoader(valid_data,
                  batch_size=BATCH_SIZE,
                  shuffle=True,
                  collate_fn=collate_batch)
test_dataloader = DataLoader(test_data,
                  batch_size=BATCH_SIZE,
                  shuffle=True,
                  collate_fn=collate_batch)

In the code, we set the batch size to 64 and use a GPU if available. We also define a collation function called collate_batch() and pass it into our dataloaders to execute our data pipelines.
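Reviews have different lengths, so the collation function pads them to the length of the longest review in the batch. The following sketch shows what pad_sequence() returns for two made-up reviews; the token indices are invented, and using 1 as the padding index is an assumption that mirrors the padding value in collate_batch().

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 1  # assumed padding index, as in collate_batch()
reviews = [torch.tensor([5, 8, 2]), torch.tensor([7])]  # made-up token indices
batch = pad_sequence(reviews, padding_value=PAD_IDX)
print(batch.shape)  # (longest review, batch size)
print(batch)
```

Note that the result is laid out as sequence length by batch size, so each column holds one review.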

Now that we have configured our pipelines and dataloaders, let’s define our model.

Model Design

For this example we will use a model called FastText from the paper “Bag of Tricks for Efficient Text Classification” by Armand Joulin et al. While many sentiment analysis models use RNNs, this model uses a simpler approach instead.

The following code implements the FastText model:

import torch.nn as nn
import torch.nn.functional as F

class FastText(nn.Module):
    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 output_dim,
                 pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim,
                            output_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        embedded = embedded.permute(1, 0, 2)
        pooled = F.avg_pool2d(
            embedded,
            (embedded.shape[1], 1)).squeeze(1)
        return self.fc(pooled)

As you can see, the model calculates the word embedding for each word using the nn.Embedding layer, and then it calculates the average of all the word embeddings with the avg_pool2d() function. Finally, it feeds the average through a linear layer. Refer to the paper for more details on this model.
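The pooling call can look cryptic, so here is a small check, using a random tensor as a stand-in for real embeddings, that the avg_pool2d() expression from forward() amounts to a plain mean over the sequence dimension. The tensor sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F

embedded = torch.randn(4, 7, 100)  # [batch, sequence length, embedding dim]
pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
mean = embedded.mean(dim=1)        # plain average over the sequence axis
print(pooled.shape)
print(torch.allclose(pooled, mean, atol=1e-6))
```

Each review is therefore reduced to a single 100-element vector regardless of its length, which is what lets the linear layer accept it.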

Let’s build our model with its appropriate parameters using the following code:

model = FastText(
            vocab_size=len(vocab),
            embedding_dim=100,
            output_dim=1,
            # the vocabulary's special tokens are lowercase
            pad_idx=vocab['<pad>'])

Rather than train our embedding layer from scratch, we’ll initialize the layer’s weights with pretrained embeddings. This process is similar to how we used pretrained weights in the transfer learning example in “Image Classification with Transfer Learning”:

pretrained_embeddings = vocab.vectors  # 1
model.embedding.weight.data.copy_(
                    pretrained_embeddings)  # 2

EMBEDDING_DIM = 100
# the vocabulary's special tokens are lowercase
unk_idx = vocab['<unk>']
pad_idx = vocab['<pad>']
model.embedding.weight.data[unk_idx] = \
      torch.zeros(EMBEDDING_DIM)  # 3
model.embedding.weight.data[pad_idx] = \
      torch.zeros(EMBEDDING_DIM)  # 4
1

Load the pretrained embedding from our vocabulary.

2

Initialize the embedding layer’s weights.

3

Initialize the embedding weights of an unknown token to zero.

4

Initialize the embedding weights of a pad token to zero.

Now that it’s initialized properly, we can train our model.

Training and Validation

The training and validation process should look familiar. It’s similar to the one we’ve used in previous examples. First we configure our loss function and our optimizer algorithm, as shown in the following code:

import torch.optim as optim

optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()

model = model.to(device)
criterion = criterion.to(device)

In this example, we are using the Adam optimizer and the BCEWithLogitsLoss() loss function. The Adam optimizer is an alternative to SGD that adapts the learning rate for each parameter and often performs better with sparse or noisy gradients. The BCEWithLogitsLoss() function combines a sigmoid activation with binary cross-entropy loss and is commonly used for binary classification. We also move our model to a GPU if available.
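To see what this loss measures, the following pure-Python sketch computes binary cross-entropy from a single raw model output (a logit) and a 0/1 label. The function bce_with_logits() is a hypothetical helper written for illustration; PyTorch's implementation uses a more numerically stable formulation and averages over the batch.

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy of sigmoid(logit) against a 0/1 target."""
    prob = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

undecided = bce_with_logits(0.0, 1.0)        # probability 0.5
confident_right = bce_with_logits(3.0, 1.0)  # high probability, correct label
confident_wrong = bce_with_logits(3.0, 0.0)  # high probability, wrong label
print(undecided, confident_right, confident_wrong)
```

The loss grows quickly when the model is confident and wrong, which is the behavior that pushes the predictions toward the correct labels during training.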

Next we run our training and validation loops, as shown in the following code:

for epoch in range(5):
  epoch_loss = 0
  epoch_acc = 0

  model.train()
  for label, text in train_dataloader:
      optimizer.zero_grad()
      predictions = model(text).squeeze(1)
      loss = criterion(predictions, label)

      rounded_preds = torch.round(
          torch.sigmoid(predictions))
      correct = \
        (rounded_preds == label).float()
      acc = correct.sum() / len(correct)

      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()
      epoch_acc += acc.item()

  print("Epoch %d Train: Loss: %.4f Acc: %.4f" %
          (epoch,
          epoch_loss / len(train_dataloader),
          epoch_acc / len(train_dataloader)))

  epoch_loss = 0
  epoch_acc = 0
  model.eval()
  with torch.no_grad():
    for label, text in valid_dataloader:
      predictions = model(text).squeeze(1)
      loss = criterion(predictions, label)

      rounded_preds = torch.round(
          torch.sigmoid(predictions))
      correct = \
        (rounded_preds == label).float()
      acc = correct.sum() / len(correct)

      epoch_loss += loss.item()
      epoch_acc += acc.item()

  print("Epoch %d Valid: Loss: %.4f Acc: %.4f" %
          (epoch,
          epoch_loss / len(valid_dataloader),
          epoch_acc / len(valid_dataloader)))

# out: (your results may vary)
# Epoch 0 Train: Loss: 0.6523 Acc: 0.7165
# Epoch 0 Valid: Loss: 0.5259 Acc: 0.7474
# Epoch 1 Train: Loss: 0.5935 Acc: 0.7765
# Epoch 1 Valid: Loss: 0.4571 Acc: 0.7933
# Epoch 2 Train: Loss: 0.5230 Acc: 0.8257
# Epoch 2 Valid: Loss: 0.4103 Acc: 0.8245
# Epoch 3 Train: Loss: 0.4559 Acc: 0.8598
# Epoch 3 Valid: Loss: 0.3828 Acc: 0.8549
# Epoch 4 Train: Loss: 0.4004 Acc: 0.8813
# Epoch 4 Valid: Loss: 0.3781 Acc: 0.8675

We should see validation accuracies around 85–90% with only five epochs of training. Let’s see how our model performs against the test dataset.

Testing and Deployment

Earlier, we constructed our test_dataloader based on the IMDb test dataset. Recall that none of the data in the test dataset has been used for training or validation.

Our test loop is shown in the following code:

test_loss = 0
test_acc = 0
model.eval()  # 1
with torch.no_grad():  # 1
  for label, text in test_dataloader:
    predictions = model(text).squeeze(1)
    loss = criterion(predictions, label)

    rounded_preds = torch.round(
        torch.sigmoid(predictions))
    correct = \
      (rounded_preds == label).float()
    acc = correct.sum() / len(correct)

    test_loss += loss.item()
    test_acc += acc.item()

print("Test: Loss: %.4f Acc: %.4f" %
        (test_loss / len(test_dataloader),
        test_acc / len(test_dataloader)))
# out: (your results will vary)
#   Test: Loss: 0.3821 Acc: 0.8599
1

Not necessary for this model, but good practice.

In the preceding code, we process one batch at a time and accumulate the accuracy over the entire test dataset. You should get 85–90% accuracy on the test set as well.

Next we’ll predict the sentiment of our own reviews, using the following code:

import spacy
nlp = spacy.load('en_core_web_sm')

def predict_sentiment(model, sentence):
    model.eval()
    text = torch.tensor(text_pipeline(
      sentence)).unsqueeze(1).to(device)
    prediction = torch.sigmoid(model(text))
    return prediction.item()

sentiment = predict_sentiment(model,
                  "Don't waste your time")
print(sentiment)
# out: 4.763594888613835e-34

sentiment = predict_sentiment(model,
                  "You gotta see this movie!")
print(sentiment)
# out: 0.941755473613739

A result close to 0 corresponds to a negative review, while a result close to 1 indicates a positive review. As you can see, the model correctly predicted the sentiment of both sample reviews. Try it with some of your own movie reviews!

Finally, we’ll save our model for deployment as shown in the following code:

torch.save(model.state_dict(), 'fasttext-model.pt')

In this example, you learned how to preprocess text data and designed a FastText model for sentiment analysis. You also trained the model, evaluated its performance, and saved the model for deployment. You can use this design pattern and reference code to solve other sentiment analysis problems in your own work.

This example was based on the “Faster Sentiment Analysis” tutorial by Ben Trevett. You can find more details and other great Torchtext tutorials in his PyTorch Sentiment Analysis GitHub repository.

Let’s move on to our final reference design, in which we will use deep learning and PyTorch to generate image data.

Generative Learning—Generating Fashion-MNIST Images with DCGAN

One of the most interesting areas of deep learning is generative learning, in which NNs are used to create data. Sometimes these NNs can create images, music, text, and time series data so well that it is difficult to tell the difference between real data and the generated data. Generative learning is used to create images of people and places that don’t exist, increase image resolution, predict frames in video, augment datasets, generate news articles, and convert styles of art and music.

In this section, I’ll show you how to use PyTorch for generative learning. The development process is similar to the previous examples; however, here we’ll use an unsupervised approach in which the data is not labeled.

In addition, we’ll design and train a GAN, which is quite different from the models and training loops of previous examples. Testing and evaluating the GAN involves a slightly different process as well. The overall development sequence is consistent with the process in Chapter 2, but each part will be unique to generative learning.

In this example, we will train a GAN to generate images similar to the training images used in the Fashion-MNIST dataset. Fashion-MNIST is a popular academic dataset used for image classification that includes images of articles of clothing. Let’s access the Fashion-MNIST data to get an idea of what these images look like, and then we’ll create some synthetic images based on what we’ve seen.

Data Processing

Unlike models used for supervised learning, where the model learns the relationships between data and labels, generative models look to learn the distribution of the training data so as to generate data similar to the training data at hand. Therefore, in this example we only need training data, because if we build a good model and train it long enough, the model should begin to produce good synthetic data.

First let’s import the required libraries, define some constants, and set our device:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

CODING_SIZE = 100
BATCH_SIZE = 32
IMAGE_SIZE = 64

device = torch.device("cuda:0" if
  torch.cuda.is_available() else "cpu")

The following code loads the training data, defines the transforms, and creates a dataloader for batch iteration:

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

dataset = datasets.FashionMNIST(
                './',
                train=True,
                download=True,
                transform=transform)

dataloader = DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=8)

This code should look familiar to you. We are once again using Torchvision functions to define the transforms, create a dataset, and set up a dataloader that will sample the dataset, apply transforms, and return a batch of images for our model.

We can display a batch of images with the following code:

from torchvision.utils import make_grid
import matplotlib.pyplot as plt

data_batch, labels_batch = next(iter(dataloader))
grid_img = make_grid(data_batch, nrow=8)
plt.imshow(grid_img.permute(1, 2, 0))

Torchvision provides a nice utility called make_grid that arranges a batch of images into a single grid image, which we then display with Matplotlib. Figure 4-2 shows an example batch of Fashion-MNIST images.

Figure 4-2. Fashion-MNIST images

Let’s see what model we’ll use for our data generation task.

Model Design

To generate new image data, we’ll use a GAN. The goal of the GAN model is to generate “fake” data based on the training data’s distribution. The GAN accomplishes this goal with two distinct modules: the generator and the discriminator.

The job of the generator is to generate fake images that look real. The job of the discriminator is to correctly identify whether an image is real or fake. Although the design of GANs is beyond the scope of this book, I’ll provide a sample reference design using a deep convolutional GAN, or DCGAN.

Note

GANs were first described in the famous paper by Ian Goodfellow et al. in 2014 titled “Generative Adversarial Nets”. Guidelines for building more stable convolutional GANs were proposed by Alec Radford et al. in the 2015 paper “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”. This paper describes the DCGAN used in this example.

The generator is designed to create an image from an input vector of 100 random values. Here’s the code:

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, coding_sz):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(coding_sz,
                               1024, 4, 1, 0),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024,
                               512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512,
                               256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256,
                               128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128,
                               1, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, input):
        return self.net(input)

netG = Generator(CODING_SIZE).to(device)

This example generator uses 2D convolutional transpose layers with batch normalization and ReLU activations. The layers are defined in the __init__() function. It works like our image classification models, except in reverse order.

That is, instead of reducing an image to a smaller representation, it takes a random vector and creates a full image from it. We also instantiate the Generator module as netG.
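To see how the generator grows a 1×1 input into a 64×64 image, we can work out the spatial size after each transposed convolution by hand. The following sketch assumes the default dilation of 1 and no output padding, in which case the output size is (in - 1) * stride - 2 * padding + kernel_size. The helper name conv_transpose_out is my own and is not part of PyTorch:

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel_size, stride, padding) for each ConvTranspose2d layer in Generator
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the coding vector is reshaped to coding_sz x 1 x 1
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

The final size of 64 matches IMAGE_SIZE, so the generated images have the same dimensions as the resized training images.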

Next, we create the Discriminator module, as shown in the following code:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,
              self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1024, 4, 2, 1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.net(input)

netD = Discriminator().to(device)

The discriminator is a binary classification network that determines the probability that the input image is real. This example discriminator NN uses 2D convolutional layers with batch normalization and leaky ReLU activation functions. We instantiate the Discriminator as netD.
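The discriminator mirrors the generator by shrinking a 64×64 image down to a single value. As a rough check, the sketch below applies the standard convolution size formula, floor((in + 2 * padding - kernel_size) / stride) + 1, assuming a dilation of 1. The helper name conv_out is hypothetical:

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel_size, stride, padding) for each Conv2d layer in Discriminator
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # IMAGE_SIZE
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

Because the last layer produces one 1×1 output per image, the discriminator returns a tensor of shape (N, 1, 1, 1), which is why the training code flattens it with view(-1) before computing the loss.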

The authors of the DCGAN paper found that it helps to initialize the weights as shown in the following code:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)
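If you want to confirm that the initialization took effect, you can apply the same function to a small block and inspect the statistics. This is only a sanity-check sketch; the block variable is a stand-in that I made up for illustration, and weights_init is repeated so the example runs on its own:

```python
import torch
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)

# Hypothetical stand-in for one Conv2d/BatchNorm2d pair from the discriminator.
block = nn.Sequential(nn.Conv2d(1, 128, 4, 2, 1), nn.BatchNorm2d(128))
block.apply(weights_init)  # apply() visits every submodule, then block itself

conv, bn = block[0], block[1]
print(conv.weight.std().item())     # close to 0.02
print(bn.weight.mean().item())      # close to 1.0
print(bn.bias.abs().max().item())   # 0.0
```

Note that the function matches on class names, so it leaves containers such as nn.Sequential untouched, and it does not modify the convolution biases.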

Now that we have designed our two modules, we can set up and train the GAN.

Training

Training a GAN is somewhat more complicated than the previous training examples. In each iteration of the training loop, we will first pass a real batch of data to the discriminator, then use the generator to create a fake batch, and then pass the generated fake batch to the discriminator before updating it. Lastly, we will train the generator NN to produce better fakes.

This is a good example of how powerful PyTorch is when creating custom training loops. It provides the flexibility to develop and implement new ideas with ease.

Before we start training, we need to define the loss function and optimizers that will be used to train the generator and the discriminator:

from torch import optim

criterion = nn.BCELoss()

optimizerG = optim.Adam(netG.parameters(),
                        lr=0.0002,
                        betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(),
                        lr=0.0001,
                        betas=(0.5, 0.999))

In the preceding code, we define the loss function and one optimizer for each network. We use the binary cross entropy (BCE) loss function, which is commonly used for binary classification. Remember that the discriminator is performing binary classification by classifying an image as real or fake. We use the popular Adam optimizer for updating the model parameters, with a separate learning rate for each network.
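For a single prediction p and a label y of 0 or 1, the BCE loss is -[y log(p) + (1 - y) log(1 - p)], and nn.BCELoss averages this value over the batch by default. The plain-Python sketch below illustrates how the loss behaves for a few hand-picked predictions; the helper name bce is my own:

```python
import math

def bce(p, y):
    """Binary cross entropy for one prediction p in (0, 1) and a label y of 0 or 1."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

print(round(bce(0.9, 1), 3))  # 0.105: confident and correct about a real image
print(round(bce(0.1, 1), 3))  # 2.303: confident and wrong about a real image
print(round(bce(0.1, 0), 3))  # 0.105: confident and correct about a fake image
```

The loss is small when the prediction agrees with the label and grows quickly as a confident prediction turns out to be wrong.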

Let’s define values for the real and fake labels and create tensors for computing the loss:

real_labels = torch.full((BATCH_SIZE,),
                       1.,
                       dtype=torch.float,
                       device=device)

fake_labels = torch.full((BATCH_SIZE,),
                       0.,
                       dtype=torch.float,
                       device=device)
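These fixed-size label tensors assume that every batch contains exactly BATCH_SIZE images. That holds here because the 60,000 Fashion-MNIST training images divide evenly into batches of 32, but it may not hold for other datasets or batch sizes. One possible workaround, sketched below with a made-up discriminator output, is to build the labels from the output itself:

```python
import torch

# Stand-in for netD(images).view(-1) on a final, smaller batch of 7 images.
output = torch.rand(7)

# ones_like/zeros_like copy the shape, dtype, and device of output.
real_labels = torch.ones_like(output)
fake_labels = torch.zeros_like(output)

criterion = torch.nn.BCELoss()
loss = criterion(output, real_labels)
print(real_labels.shape, loss.item())
```

Another option is to pass drop_last=True to the DataLoader so that an incomplete final batch is skipped.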

Before we start training, we will create lists for storing the errors and define a test vector to show the results later:

G_losses = []
D_losses = []
D_real = []
D_fake = []

z = torch.randn((
    BATCH_SIZE, CODING_SIZE)).view(
        -1, CODING_SIZE, 1, 1).to(device)
test_out_images = []

Now we can execute the training loop. If the GAN is stable, it should improve as more epochs are trained. The training loop is shown in the following code:

N_EPOCHS = 5

for epoch in range(N_EPOCHS):
  print(f'Epoch: {epoch}')
  for i, batch in enumerate(dataloader):
    if (i%200==0):
      print(f'batch: {i} of {len(dataloader)}')

    # Train Discriminator with an all-real batch.
    netD.zero_grad()
    # Scale images from [0, 1] to [-1, 1] to match
    # the range of the Generator's Tanh output.
    real_images = batch[0].to(device) * 2. - 1.
    output = netD(real_images).view(-1)  # (1)
    errD_real = criterion(output, real_labels)
    D_x = output.mean().item()

    # Train Discriminator with an all-fake batch.
    noise = torch.randn((BATCH_SIZE,
                         CODING_SIZE))
    noise = noise.view(
        -1, CODING_SIZE, 1, 1).to(device)
    fake_images = netG(noise)
    output = netD(fake_images).view(-1)  # (2)
    errD_fake = criterion(output, fake_labels)
    D_G_z1 = output.mean().item()
    errD = errD_real + errD_fake
    errD.backward(retain_graph=True)  # (3)
    optimizerD.step()

    # Train Generator to generate better fakes.
    netG.zero_grad()
    output = netD(fake_images).view(-1)  # (4)
    errG = criterion(output, real_labels)  # (5)
    errG.backward()  # (6)
    D_G_z2 = output.mean().item()
    optimizerG.step()

    # Save losses for plotting later.
    G_losses.append(errG.item())
    D_losses.append(errD.item())

    D_real.append(D_x)
    D_fake.append(D_G_z2)

  test_images = netG(z).to('cpu').detach()  # (7)
  test_out_images.append(test_images)

(1) Pass real images to the Discriminator.

(2) Pass fake images to the Discriminator.

(3) Run backpropagation and update the Discriminator. Passing retain_graph=True keeps the graph that produced fake_images so it can be reused in the Generator step.

(4) Pass fake images to the updated Discriminator.

(5) The Generator loss compares the Discriminator's output on fake images with the real labels, so it is low when the Discriminator is wrong.

(6) Run backpropagation and update the Generator.

(7) Create a batch of images and save them after each epoch.

As we’ve done in the previous examples, we loop through all the data, one batch at a time, using the dataloader during each epoch. First we pass a batch of real images to the discriminator, compute its output, and calculate the loss against the real labels. Then we do the same with a batch of fake images.

The fake images are created by the generator from a vector of random values. Again, we compute the discriminator output and calculate the loss, this time against the fake labels. Next, we add the real and fake losses, apply backpropagation to compute the gradients, and update the discriminator.

We compute the outputs from the freshly updated discriminator using the same fake data, and compute the loss or error of the generator against the real labels, so the generator is rewarded when the discriminator mistakes its fakes for real images. Using this loss, we compute the gradients with backpropagation and update the generator itself.
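The loop above reuses the generator's graph by calling backward() with retain_graph=True. A common variation, which is not the approach used in this reference design, is to detach the fake images during the discriminator step so that no gradients flow into the generator at that point. The sketch below shows one iteration of that variation using tiny linear networks as stand-ins for netG and netD; the layer sizes and data are made up so the example runs quickly on a CPU:

```python
import torch
import torch.nn as nn
from torch import optim

torch.manual_seed(0)

# Tiny hypothetical stand-ins for the generator and discriminator.
netG = nn.Sequential(nn.Linear(8, 16), nn.Tanh())
netD = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))

real = torch.rand(4, 16) * 2. - 1.  # pretend batch of real data in [-1, 1]
noise = torch.randn(4, 8)

# Discriminator step: detach() stops gradients from flowing into netG.
netD.zero_grad()
out_real = netD(real).view(-1)
fake = netG(noise)
out_fake = netD(fake.detach()).view(-1)
errD = (criterion(out_real, torch.ones_like(out_real)) +
        criterion(out_fake, torch.zeros_like(out_fake)))
errD.backward()  # no retain_graph needed
g_grads_after_d_step = [p.grad for p in netG.parameters()]
optimizerD.step()

# Generator step: a fresh pass through the updated discriminator.
netG.zero_grad()
out = netD(fake).view(-1)
errG = criterion(out, torch.ones_like(out))
errG.backward()
optimizerG.step()
```

Both versions update the discriminator and the generator in the same order. The detached version simply avoids computing generator gradients that would be discarded by netG.zero_grad() anyway.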

Lastly, we’ll keep track of the loss after each batch to see if the GAN’s training is consistently improving and stable. Figure 4-3 shows the loss curve for both the generator and the discriminator during training.

Figure 4-3. GAN training curves

The loss curves plot the generator and the discriminator loss for each batch over all epochs, so the curves bounce around from batch to batch. We can see, though, that the loss in both cases has been reduced from the beginning of training. If we trained over more epochs, we’d look for these loss values to settle into a stable range rather than drop to zero: because the two networks compete, a loss near zero for one of them usually means the other has stopped learning.
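Because the per-batch losses are noisy, it can help to smooth them before plotting. The following sketch computes a simple moving average in plain Python; the function name moving_average and the sample values are hypothetical, and in practice you would pass in G_losses or D_losses:

```python
def moving_average(values, window):
    """Average each run of `window` consecutive values."""
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        # Slide the window: add the newest value, drop the oldest.
        running += values[i] - values[i - window]
        averages.append(running / window)
    return averages

# Hypothetical per-batch losses.
losses = [4.0, 2.0, 6.0, 4.0, 2.0]
print(moving_average(losses, 2))  # [3.0, 4.0, 5.0, 3.0]
```

A larger window gives a smoother curve at the cost of hiding short-lived spikes, which can themselves be a sign of unstable training.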

In general, GANs are tricky to train, and the learning rate, betas, and other optimizer hyperparameters can have a major impact.

Let’s examine the average output of the discriminator for each batch over all the epochs, as shown in Figure 4-4.

Figure 4-4. Discriminator results

If the GAN were perfect, the discriminator would not be able to tell fake images from real ones, and we would expect its average output to be 0.5 in both cases. The results show that some batches are close to 0.5, but we can certainly do better.
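As a rough reference point, we can compute what the BCE losses would be in the idealized case where the discriminator outputs exactly 0.5 for every image. This is only arithmetic on the loss formula, not a measurement from the trained model:

```python
import math

d_output = 0.5  # discriminator is maximally unsure about every image

err_real = -math.log(d_output)       # BCE against a label of 1
err_fake = -math.log(1 - d_output)   # BCE against a label of 0
errD = err_real + err_fake           # same sum used in the training loop
errG = -math.log(d_output)           # generator loss uses the real labels

print(round(errD, 3), round(errG, 3))  # 1.386 0.693
```

Loss values hovering near these numbers, together with discriminator outputs near 0.5, would suggest that neither network is overpowering the other.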

Now that we have trained our network, let’s see how well it does at creating fake images of clothing.

Testing and Deployment

During supervised learning, we usually set aside a test dataset that has not been used to train or validate the model. In generative learning, there are no labels produced by the generator. We could pass our generated images into a Fashion-MNIST classifier, but there’s no way for us to know if the errors are caused by the classifier or the GAN unless we hand-label the outputs.

For now, let’s test and evaluate our GAN by comparing the generated images from the first epoch with those from the last epoch. We created a test vector, z, before training, and the training loop saved the generator’s output for that vector at the end of each epoch.

Figure 4-5 shows the generated images from the first epoch, while Figure 4-6 shows the results after training only five epochs.

Figure 4-5. Generator results (first epoch)

Figure 4-6. Generator results (last epoch)

You can see that the generator has improved somewhat. Look at the boot at the end of the second row or the shirt at the end of the third row. Our GAN is not perfect, but it seems to be improving after just five epochs. Training over more epochs or improving our design might produce even better results.

Finally, we can save our trained model for deployment and use it to generate more synthetic Fashion-MNIST images using the following code:

torch.save(netG.state_dict(), './gan.pt')
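To use the saved weights later, you rebuild the same architecture and load the state dictionary into it. The sketch below shows the round trip with a deliberately tiny generator so that it runs on its own; TinyGenerator is a hypothetical stand-in for the Generator class defined earlier, and the temporary directory replaces the './gan.pt' path:

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the Generator class defined earlier."""
    def __init__(self, coding_sz):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(coding_sz, 1, 4, 1, 0),
            nn.Tanh())

    def forward(self, input):
        return self.net(input)

CODING_SIZE = 100
trained = TinyGenerator(CODING_SIZE)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'gan.pt')
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it.
    netG = TinyGenerator(CODING_SIZE)
    netG.load_state_dict(torch.load(path, map_location='cpu'))

# The full Generator has BatchNorm layers, which should use their
# running statistics at inference time, so switch to eval mode.
netG.eval()
with torch.no_grad():
    noise = torch.randn(16, CODING_SIZE, 1, 1)
    images = (netG(noise) + 1.) / 2.  # map Tanh output from [-1, 1] to [0, 1]

print(images.shape)  # torch.Size([16, 1, 4, 4])
```

With the full Generator from this section, the same steps would produce 64×64 images instead of the 4×4 outputs of this stand-in.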

We expanded our PyTorch deep learning capabilities by designing and training a GAN in this generative learning reference design. You can use this reference design to create and train other GAN models and test their performance at generating new data.

In this chapter, we covered additional examples to show a variety of data processing, model design, and training approaches with PyTorch—but what if you have an amazing idea for some new, innovative NN? Or what if you come up with a new optimization algorithm or loss function that nobody’s seen before? In the next chapter, I’ll show you how to create your own custom modules and functions so you can expand your deep learning research and experiment with new ideas.
