- We start by importing the necessary libraries, as follows:
import matplotlib.pyplot as pltimport itertools import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable from torch.utils.data.dataset import Dataset import torchvision.datasets as dset import torchvision.transforms as transforms
- We then define our discriminator network as a class that subclasses `nn.Module`:
class discriminator(nn.Module): def __init__(self): super(discriminator, self).__init__() self.conv1 = nn.Conv2d(1, d, 4, 2, 1, bias=False) self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1, bias=False) self.conv2_bn = nn.BatchNorm2d(*2) self.conv3 = nn.Conv2d(d*2, *4, 4, 2, 1, bias=False) self.conv3_bn = nn.BatchNorm2d ...