Now, we'll create the build_gan.py file. As usual, we'll begin with the imports:
```python
import itertools
import os
import time
from datetime import datetime

import numpy as np
import torch
import torchvision.utils as vutils

import utils
from cyclegan import Generator as cycG
from cyclegan import Discriminator as cycD
```
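The `itertools` import points to a common CycleGAN pattern: using `itertools.chain` to pass the parameters of both generators to a single optimizer, so that one `step()` updates both translation directions. The sketch below assumes that pattern. The tiny `nn.Conv2d` "generators" (`netG_AB`, `netG_BA`), the learning rate, and the Adam betas are illustrative stand-ins, not the chapter's actual `cycG` networks or settings.

```python
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the two CycleGAN generators (A -> B and B -> A).
# The real networks would come from cyclegan.Generator; tiny layers keep this fast.
netG_AB = nn.Conv2d(3, 3, kernel_size=3, padding=1)
netG_BA = nn.Conv2d(3, 3, kernel_size=3, padding=1)

# itertools.chain concatenates the two parameter iterators, so a single Adam
# optimizer owns the parameters of both generators.
# lr and betas are assumed values, commonly used for CycleGAN training.
optim_G = torch.optim.Adam(
    itertools.chain(netG_AB.parameters(), netG_BA.parameters()),
    lr=2e-4,
    betas=(0.5, 0.999),
)

x_a = torch.randn(1, 3, 8, 8)
# Cycle-consistency style loss: A -> B -> A should reconstruct x_a.
loss = F.l1_loss(netG_BA(netG_AB(x_a)), x_a)

optim_G.zero_grad()
loss.backward()
optim_G.step()  # updates both generators at once

# Each Conv2d has a weight and a bias, so the optimizer holds 4 tensors.
n_params = sum(len(group["params"]) for group in optim_G.param_groups)
print(n_params)  # 4
```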
We'll need a function to initialize the weights of the convolution and batch normalization layers:
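```python
def _weights_init(m):
    """Initialize Conv and BatchNorm layers, matched by class name."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Convolution weights ~ N(0, 0.02)
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # Batch-norm scale ~ N(1, 0.02), shift set to 0
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
```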
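A function like `_weights_init` is typically passed to `torch.nn.Module.apply`, which calls it on every submodule recursively. The sketch below applies it to a small hypothetical `Sequential` network (not the chapter's `cycG` or `cycD`) and checks the result: convolution weights drawn from N(0, 0.02), batch-norm scales from N(1, 0.02), and batch-norm biases set to zero.

```python
import torch
import torch.nn as nn


def _weights_init(m):
    # Same rule as above: match layers by class name.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# A small hypothetical network standing in for the real generator/discriminator.
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
)

torch.manual_seed(0)
# Module.apply visits every submodule (and net itself) and calls _weights_init.
net.apply(_weights_init)

conv, bn = net[0], net[1]
print(conv.weight.std().item())  # should be close to 0.02
print(bn.weight.mean().item())   # should be close to 1.0
print(bn.bias.abs().max().item())  # exactly 0.0
```

Because the rule matches on the class name, any layer whose name contains `Conv` (including `ConvTranspose2d`) is initialized, while convolution biases keep PyTorch's default initialization.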
Now, we will create the Model class:
class Model(object): def __init__(self, name, device, data_loader, ...