In this section, we'll implement a pre-activation ResNet to classify CIFAR-10 images using PyTorch 1.3.1 and torchvision 0.4.2. Let's start:
- As usual, we'll start with the imports. Note that we'll use the shorthand F for the PyTorch functional module, torch.nn.functional (https://pytorch.org/docs/stable/nn.html#torch-nn-functional):
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
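To show what the F shorthand gives us, here is a minimal sketch (not part of the original walkthrough). The functional module exposes stateless versions of the layers in torch.nn. They compute the same thing, but any parameters, such as convolution weights, must be passed in explicitly. The tensor shapes below are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Arbitrary illustrative batch: 2 samples, 3 channels (RGB, as in CIFAR-10), 4x4 pixels
x = torch.randn(2, 3, 4, 4)

# Module form: instantiate a layer object once, then call it
relu_module = nn.ReLU()
out_module = relu_module(x)

# Functional form: a plain function call with no state to register
out_functional = F.relu(x)
print(torch.equal(out_module, out_functional))  # True

# Layers with parameters: nn.Conv2d owns its weight tensor,
# while F.conv2d expects the weight to be passed in as an argument
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
out_conv_module = conv(x)
out_conv_functional = F.conv2d(x, conv.weight, padding=1)
print(torch.allclose(out_conv_module, out_conv_functional))  # True
print(out_conv_module.shape)  # torch.Size([2, 8, 4, 4])
```

A common convention is to use nn modules for layers with learnable parameters, so that nn.Module registers them automatically, and to use F for parameter-free operations such as activations inside forward.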
- Next, let's define the pre-activation regular (non-bottleneck) residual block. We'll implement it as a subclass of nn.Module, the base class for all neural network modules. Let's start with the class definition and the __init__ method: