We need to make some adjustments to the training API so that we can make use of the class and style vectors for attribute extraction and image generation.
First, we add several imports to the build_gan.py file:
import itertools
from infogan import Generator as infoganG
from infogan import Discriminator as infoganD
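The excerpt does not say what itertools is used for. In many InfoGAN implementations it is used to pass the parameters of both networks to a single optimizer, so that the mutual-information loss updates the generator and the discriminator's class-recovery head together. The following is a minimal sketch of that pattern under that assumption. TinyGenerator, TinyDiscriminator, optim_info and the learning-rate values are hypothetical stand-ins, not the classes from infogan.py.

```python
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

N_NOISE, N_CLASSES, OUT_DIM = 2, 3, 8


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for infoganG: maps [noise, one-hot class] to a sample."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(N_NOISE + N_CLASSES, OUT_DIM)

    def forward(self, z):
        return self.net(z)


class TinyDiscriminator(nn.Module):
    """Hypothetical stand-in for infoganD with a real/fake head and a class head."""

    def __init__(self):
        super().__init__()
        self.adv = nn.Linear(OUT_DIM, 1)          # real/fake score
        self.cls = nn.Linear(OUT_DIM, N_CLASSES)  # tries to recover the class code

    def forward(self, x):
        return self.adv(x), self.cls(x)


torch.manual_seed(0)
netG, netD = TinyGenerator(), TinyDiscriminator()

# itertools.chain lets one optimizer own the parameters of both networks.
optim_info = torch.optim.Adam(
    itertools.chain(netG.parameters(), netD.parameters()),
    lr=2e-4, betas=(0.5, 0.999),
)

# One information-loss step: the discriminator must recover the class code
# that was fed into the generator.
batch = 16
z_class = torch.randint(0, N_CLASSES, (batch,))
z = torch.cat([torch.randn(batch, N_NOISE),
               F.one_hot(z_class, N_CLASSES).float()], dim=1)
_, class_logits = netD(netG(z))
info_loss = F.cross_entropy(class_logits, z_class)

g_weight_before = netG.net.weight.detach().clone()
optim_info.zero_grad()
info_loss.backward()
optim_info.step()  # updates netG and netD in the same step
```

Because the optimizer was built from the chained parameters, a single call to `step()` moves the generator weights and the discriminator weights together.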
PyTorch's default weight initialization can easily lead to saturation in GAN training, so we define a custom weight initializer:
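def _weights_init(m):
    # Dispatch on the layer's class name (e.g. Conv2d, ConvTranspose2d, BatchNorm2d)
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Convolution weights ~ N(0, 0.02)
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # Batch-norm scale ~ N(1, 0.02), shift set to 0
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)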
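The initializer only takes effect once it is applied to every layer of a network. The usual way to do that is `nn.Module.apply`, which calls the function recursively on each submodule. The following is a minimal sketch that repeats `_weights_init` so it runs on its own. The small `net` is a hypothetical example and not the InfoGAN generator or discriminator.

```python
import torch
import torch.nn as nn


def _weights_init(m):
    # Same initializer as above, repeated so this snippet is self-contained.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Hypothetical network mixing the layer types the initializer handles.
net = nn.Sequential(
    nn.ConvTranspose2d(16, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),  # no 'Conv'/'BatchNorm' in the name, so left untouched
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
)

torch.manual_seed(0)
# apply() visits every submodule (and net itself) and calls _weights_init on it.
net.apply(_weights_init)

print(f"conv weight std: {net[3].weight.std().item():.4f}")  # close to 0.02
print(f"bn weight mean:  {net[1].weight.mean().item():.4f}")  # close to 1.0
```

These constants match the initializer used in PyTorch's DCGAN tutorial. Matching on the class-name substring `'Conv'` covers both `Conv2d` and `ConvTranspose2d`, so a single function works for the generator and the discriminator.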