
596 Глава 17
Функция
d_train()
для обучения дискриминатора не требует изменения размера вход-
ного изображения:
>>> def d_train(x):
... disc_model.zero_grad()
... # Обучение диск риминатора на р еальном пакет е
... batch_size = x.size(0)
... x = x.to(device)
... d_labels_real = torch.ones(batch_size, 1, device=device)
... d_proba_real = disc_model(x)
... d_loss_real = loss_fn(d_proba_real, d_labels_real)
... # Обучение диск риминатора на ф иктивном паке те
... input_z = create_noise(batch_size, z_size, mode_z).to(device)
... g_output = gen_model(input_z)
... d_proba_fake = disc_model(g_output)
... d_labels_fake = torch.zeros(batch_size, ...