Now, let's create our generator and discriminator networks. This code goes in the same simple_gan.py file:
- Define the parameters of the generator network:
class Generator(object):
    def __init__(self):
        self.z = None

        self.w1 = weight_initializer(Z_DIM, G_HIDDEN)
        self.b1 = weight_initializer(1, G_HIDDEN)
        self.x1 = None

        self.w2 = weight_initializer(G_HIDDEN, G_HIDDEN)
        self.b2 = weight_initializer(1, G_HIDDEN)
        self.x2 = None

        self.w3 = weight_initializer(G_HIDDEN, X_DIM)
        self.b3 = weight_initializer(1, X_DIM)
        self.x3 = None

        self.x = None
We keep track of the inputs and outputs of all the layers because we need them to calculate the derivatives to update the parameters later.
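To see why the cached values matter, here is a minimal sketch that is separate from simple_gan.py. It uses a single scalar unit rather than the full network, and the TinyLayer class, its tanh activation, and the numbers are all assumptions chosen for illustration, not the generator's actual layers. The backward pass cannot compute its gradients without the input and output stored during the forward pass:

```python
import math


class TinyLayer:
    """One scalar unit, x = tanh(w * z + b). Hypothetical, for illustration only."""

    def __init__(self, w, b):
        self.w = w
        self.b = b
        self.z = None  # cached input, filled in by forward()
        self.x = None  # cached output, filled in by forward()

    def forward(self, z):
        # Store the input and output so backward() can reuse them.
        self.z = z
        self.x = math.tanh(self.w * z + self.b)
        return self.x

    def backward(self, grad_x):
        """Given dLoss/dx, return (dLoss/dw, dLoss/db, dLoss/dz)."""
        # d tanh(a)/da = 1 - tanh(a)^2, which reuses the cached output.
        grad_a = grad_x * (1.0 - self.x ** 2)
        grad_w = grad_a * self.z  # needs the cached input
        grad_b = grad_a
        grad_z = grad_a * self.w  # would be passed to the previous layer
        return grad_w, grad_b, grad_z


layer = TinyLayer(w=0.5, b=0.1)
out = layer.forward(2.0)
grad_w, grad_b, grad_z = layer.backward(1.0)
```

The same pattern scales to the matrix case: each layer's weight gradient depends on that layer's stored input, and the gradient through the activation depends on its stored output.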
- Define the forward ...