Transfer learning example with PyTorch

Now that we know what transfer learning is, let's see whether it works in practice. In this section, we'll apply an advanced ImageNet pre-trained network to the CIFAR-10 images. We'll use both types of transfer learning: fine-tuning the whole network, and using it as a fixed feature extractor. It's preferable to run this example on a GPU.

  1. Do the following imports:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models, transforms
  2. Define batch_size for convenience:
batch_size = 50
  3. Define the training dataset. We have to consider a few things:
    • The CIFAR-10 images are 32 x 32, while the ImageNet network expects 224 x 224 input. Because we are using an ImageNet-based network, we'll upsample the CIFAR-10 images to 224 x 224.
