Model training and evaluation

Now, we'll create the build_gan.py file. As usual, we'll begin with the imports:

```python
import itertools
import os
import time
from datetime import datetime

import numpy as np
import torch
import torchvision.utils as vutils

import utils
from cyclegan import Generator as cycG
from cyclegan import Discriminator as cycD
```

We'll need a function to initialize the weights of the convolution and batch normalization layers:

```python
def _weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
```
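To see what the initializer does, here is a minimal sketch that applies it through `nn.Module.apply`, which calls the function on every submodule recursively. The small `nn.Sequential` network is a hypothetical stand-in for the CycleGAN generator and discriminator, whose definitions are not shown in this section.

```python
import torch
import torch.nn as nn


def _weights_init(m):
    # Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), biases = 0.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Hypothetical toy network standing in for cycG / cycD.
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
)

torch.manual_seed(0)
# apply() visits net itself and every child module, calling _weights_init on each.
net.apply(_weights_init)

conv, bn, deconv = net[0], net[1], net[3]
print(conv.weight.std().item())    # close to 0.02
print(deconv.weight.std().item())  # close to 0.02
print(bn.weight.mean().item())     # close to 1.0
print(bn.bias.abs().max().item())  # 0.0
```

Because the check matches on the class name, `ConvTranspose2d` layers are initialized too, while convolution biases keep PyTorch's default initialization.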

Now, we will create the Model class:

```python
class Model(object):
    def __init__(self,
                 name,
                 device,
                 data_loader,
                 ...
```
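The constructor is truncated here, so its remaining arguments and body are not shown. As a hypothetical sketch of the kind of setup a CycleGAN model class commonly performs, the code below builds two generators and two discriminators and uses `itertools.chain` (one plausible reason for the `itertools` import) to give both generators a single optimizer. `TinyNet`, the learning rate, and the Adam betas are assumptions for illustration, not the book's code or values.

```python
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    # Hypothetical stand-in for cycG / cycD: a single 3x3 convolution.
    def __init__(self, out_channels=3):
        super().__init__()
        self.body = nn.Conv2d(3, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.body(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

# One generator per translation direction, one discriminator per domain.
netG_AB = TinyNet().to(device)
netG_BA = TinyNet().to(device)
netD_A = TinyNet(out_channels=1).to(device)
netD_B = TinyNet(out_channels=1).to(device)

# Assumed hyperparameters (common CycleGAN defaults, not taken from the book).
lr, betas = 2e-4, (0.5, 0.999)
optim_G = torch.optim.Adam(
    itertools.chain(netG_AB.parameters(), netG_BA.parameters()), lr=lr, betas=betas)
optim_D = torch.optim.Adam(
    itertools.chain(netD_A.parameters(), netD_B.parameters()), lr=lr, betas=betas)

# One cycle-consistency step: A -> B -> A should reconstruct the input.
x_a = torch.randn(2, 3, 16, 16, device=device)
rec_a = netG_BA(netG_AB(x_a))
loss_cycle = F.l1_loss(rec_a, x_a)

optim_G.zero_grad()
loss_cycle.backward()
optim_G.step()  # a single step updates the parameters of both generators
```

Chaining the parameters is convenient because the cycle-consistency loss flows through both generators, so one backward pass and one optimizer step update them together, while the discriminators keep their own optimizer.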
