One-stop model training API

First, let's create a new Model class that serves as a wrapper for different models and provides the one-stop training API. Create a new file named build_gan.py and import the necessary modules:

import os
import numpy as np
import torch
import torchvision.utils as vutils
from cgan import Generator as cganG
from cgan import Discriminator as cganD

Then, let's create the Model class. In this class, we will initialize the Generator and Discriminator modules and provide train and eval methods so that users can simply call Model.train() (or Model.eval()) somewhere else to complete the model training (or evaluation):

class Model(object):
    def __init__(self,
                 name,
                 device,
                 data_loader,
                 classes,
                 channels,
                 img_size,
                 latent_dim):
        ...
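The excerpt stops before the body of the class, so the following is only a minimal, self-contained sketch of what such a one-stop wrapper might contain, not the book's implementation. `TinyGenerator` and `TinyDiscriminator` are hypothetical MLP stand-ins for `cganG` and `cganD` (whose definitions in cgan.py are not shown here), and the BCE loss and Adam settings (`lr=2e-4`, `betas=(0.5, 0.999)`) are common GAN defaults chosen as assumptions. The `name` argument is stored but not used. `train()` alternates one generator update and one discriminator update per batch; `eval()` generates one sample per class label.

```python
import torch
import torch.nn as nn


# Hypothetical stand-ins for cgan.Generator / cgan.Discriminator (assumption:
# the real modules in cgan.py are not shown in this excerpt).
class TinyGenerator(nn.Module):
    def __init__(self, classes, channels, img_size, latent_dim):
        super().__init__()
        self.img_shape = (channels, img_size, img_size)
        self.label_emb = nn.Embedding(classes, classes)  # label -> dense vector
        self.net = nn.Sequential(
            nn.Linear(latent_dim + classes, 128), nn.ReLU(),
            nn.Linear(128, channels * img_size * img_size), nn.Tanh())

    def forward(self, z, labels):
        # Condition on the label by concatenating its embedding to the noise.
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(x).view(z.size(0), *self.img_shape)


class TinyDiscriminator(nn.Module):
    def __init__(self, classes, channels, img_size):
        super().__init__()
        self.label_emb = nn.Embedding(classes, classes)
        self.net = nn.Sequential(
            nn.Linear(channels * img_size * img_size + classes, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1), nn.Sigmoid())  # probability that input is real

    def forward(self, img, labels):
        x = torch.cat([img.view(img.size(0), -1), self.label_emb(labels)], dim=1)
        return self.net(x)


class Model(object):
    """Sketch of a one-stop wrapper: builds G and D, exposes train()/eval()."""

    def __init__(self, name, device, data_loader, classes, channels,
                 img_size, latent_dim):
        self.name = name  # stored only; this sketch always builds the tiny cGAN
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.latent_dim = latent_dim
        self.netG = TinyGenerator(classes, channels, img_size, latent_dim).to(device)
        self.netD = TinyDiscriminator(classes, channels, img_size).to(device)
        self.optim_G = torch.optim.Adam(self.netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
        self.optim_D = torch.optim.Adam(self.netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
        self.loss = nn.BCELoss()

    def train(self, epochs=1):
        self.netG.train()
        self.netD.train()
        g_loss = d_loss = torch.tensor(float("nan"))
        for _ in range(epochs):
            for data, target in self.data_loader:
                data, target = data.to(self.device), target.to(self.device)
                n = data.size(0)
                real = torch.ones(n, 1, device=self.device)
                fake = torch.zeros(n, 1, device=self.device)

                # Generator step: push D to label generated images as real.
                self.optim_G.zero_grad()
                z = torch.randn(n, self.latent_dim, device=self.device)
                x_fake = self.netG(z, target)
                g_loss = self.loss(self.netD(x_fake, target), real)
                g_loss.backward()
                self.optim_G.step()

                # Discriminator step: real -> 1, fake -> 0 (fakes detached so G
                # is not updated; zero_grad also clears D grads from the G step).
                self.optim_D.zero_grad()
                d_real = self.loss(self.netD(data, target), real)
                d_fake = self.loss(self.netD(x_fake.detach(), target), fake)
                d_loss = (d_real + d_fake) / 2
                d_loss.backward()
                self.optim_D.step()
        return g_loss.item(), d_loss.item()

    @torch.no_grad()
    def eval(self, batch_size=None):
        # Generate one image per label (cycling through classes), no gradients.
        self.netG.eval()
        n = batch_size or self.classes
        labels = torch.arange(n, device=self.device) % self.classes
        z = torch.randn(n, self.latent_dim, device=self.device)
        return self.netG(z, labels)
```

To try it, wrap image and label tensors in a `torch.utils.data.TensorDataset`, pass a `DataLoader` over it as `data_loader`, and call `model.train()` followed by `model.eval()`. Because `Model` is a plain `object` rather than an `nn.Module`, naming its methods `train` and `eval` does not clash with PyTorch's own `Module.train()`/`Module.eval()`, which the sketch still calls on the wrapped networks.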

Get Hands-On Generative Adversarial Networks with PyTorch 1.x now with the O’Reilly learning platform.
