First, let's create a new Model class that serves as a wrapper for different models and provides a one-stop training API. Create a new file named build_gan.py and import the necessary modules:
import os
import numpy as np
import torch
import torchvision.utils as vutils
from cgan import Generator as cganG
from cgan import Discriminator as cganD
Then, let's create the Model class. In this class, we will initialize the Generator and Discriminator modules and provide train and eval methods, so that other code can simply call Model.train() (or Model.eval()) to run the model training (or evaluation):
class Model(object):
    def __init__(self, name, device, data_loader, classes, channels, img_size, latent_dim):
        ...
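To make the wrapper idea concrete, here is a minimal, framework-free sketch of the shape such a class might take. It is not the actual body of Model: the _StubNet class is a hypothetical stand-in for cganG and cganD, the netG and netD attribute names are assumptions, and so are the epochs parameter and the batch count returned by train. The constructor signature is the only part taken from the text above.

```python
class _StubNet(object):
    """Hypothetical stand-in for cgan.Generator / cgan.Discriminator."""

    def __init__(self, role):
        self.role = role
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


class Model(object):
    def __init__(self, name, device, data_loader, classes, channels,
                 img_size, latent_dim):
        # Keep the configuration so train() and eval() can use it later.
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        # Assumption: this sketch only knows one model name.
        if name != 'cgan':
            raise ValueError('unsupported model name: %s' % name)
        self.netG = _StubNet('generator')      # assumed attribute name
        self.netD = _StubNet('discriminator')  # assumed attribute name

    def train(self, epochs):
        """Walk the data loader; real code would update D and G per batch."""
        self.netG.train()
        self.netD.train()
        batches_seen = 0
        for _ in range(epochs):
            for _batch in self.data_loader:
                batches_seen += 1  # placeholder for the optimization step
        return batches_seen

    def eval(self):
        """Switch both networks to evaluation mode."""
        self.netG.eval()
        self.netD.eval()


# Usage: a plain list of batches stands in for a real data loader.
model = Model('cgan', 'cpu', [[1, 2], [3, 4], [5, 6]], 10, 1, 64, 100)
model.train(epochs=2)
model.eval()
```

The point of the pattern is that the calling code never touches the Generator or Discriminator directly; it builds one Model and calls train or eval on it.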