Model training and evaluation

Now, we'll create the build_gan.py file. As usual, we'll begin with the imports:

import itertools
import os
import time
from datetime import datetime

import numpy as np
import torch
import torchvision.utils as vutils

import utils
from cyclegan import Generator as cycG
from cyclegan import Discriminator as cycD

We'll need a function to initialize the weights:

def _weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
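The initializer picks a branch purely from the layer's class name, so it helps to see which names match which test. The sketch below is an illustration rather than part of build_gan.py: the empty classes are stand-ins that only borrow the names of common torch.nn layers, and branch_for is a hypothetical helper that mirrors the two find checks without touching any weights.

```python
# Stand-in classes: only their names matter here, not their behaviour.
# They are hypothetical placeholders, not the real torch.nn layers.
class Conv2d: pass
class ConvTranspose2d: pass
class BatchNorm2d: pass
class InstanceNorm2d: pass
class ReLU: pass


def branch_for(m):
    """Return which branch of _weights_init an object would take.

    Hypothetical helper: it repeats the same substring tests on the
    class name but returns a label instead of initializing weights.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        return 'conv'
    elif classname.find('BatchNorm') != -1:
        return 'batchnorm'
    return 'untouched'


layers = [Conv2d(), ConvTranspose2d(), BatchNorm2d(), InstanceNorm2d(), ReLU()]
branches = {type(m).__name__: branch_for(m) for m in layers}
for name, branch in branches.items():
    print(f"{name:16s} -> {branch}")
```

Because the test is a substring match, any class whose name contains Conv takes the first branch, including transposed convolutions, while a layer whose name matches neither test is left with its default initialization. In PyTorch, a function of this shape is typically run over every submodule with the module's apply method, for example net.apply(_weights_init), where net is whatever network is being initialized.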

Now, we will create the Model class:

class Model(object):
    def __init__(self,
                 name,
                 device, data_loader, ...
