
606 Глава 17
... d_generated = disc_model(g_output)
... g_loss = -d_generated.mean()
...
... # обратное расп ространение гра диента и опти мизация
... # ТОЛЬКО параме тров генератора
... g_loss.backward()
... g_optimizer.step()
... return g_loss.data.item()
Затем мы обучим модель в течение 100 эпох и запишем выходные данные генератора
для фиксированного входного шума:
>>> epoch_samples_wgan = []
>>> lambda_gp = 10.0
>>> num_epochs = 100
>>> torch.manual_seed(1)
>>> critic_iterations = 5
>>> for epoch in range(1, num_epochs+1):
... gen_model.train()
... d_losses, g_losses = [], []
... for i, (x, _) in enumerate(mnist_dl):
... for _ in ra