February 2018
Intermediate to advanced
262 pages
English
We need to make some minor changes to the fit method to accommodate the three input values returned by the data loader. The following code implements the new fit function:
from torch.autograd import Variable

# is_cuda and optimizer are expected to be defined earlier in the script
def fit(epoch, model, data_loader, phase='training', volatile=False):
    if phase == 'training':
        model.train()
    if phase == 'validation':
        model.eval()
        volatile = True
    running_loss = 0.0
    running_correct = 0
    for batch_idx, (data1, data2, data3, target) in enumerate(data_loader):
        if is_cuda:
            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
        # Pass volatile by keyword: Variable's second positional argument is requires_grad
        data1 = Variable(data1, volatile=volatile)
        data2 = Variable(data2, volatile=volatile)
        data3 = Variable(data3, volatile=volatile)
        target = Variable(target)
        if phase == 'training':
            optimizer.zero_grad()
        output = model(data1, data2, data3)
        ...
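On PyTorch 0.4 and later, Variable and volatile are deprecated. For those versions, the same three-input loop can be sketched roughly as follows. This is a minimal sketch, not the book's code: ThreeInputDataset, ThreeInputModel, and the device and optimizer parameters of fit are hypothetical stand-ins. Each batch carries three tensors plus a label, and torch.set_grad_enabled turns off gradient tracking during validation in place of volatile=True.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class ThreeInputDataset(Dataset):
    """Hypothetical dataset: each item is three feature tensors plus a label."""

    def __init__(self, feats1, feats2, feats3, labels):
        self.feats1, self.feats2, self.feats3 = feats1, feats2, feats3
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.feats1[idx], self.feats2[idx], self.feats3[idx], self.labels[idx]


class ThreeInputModel(nn.Module):
    """Hypothetical model: concatenates the three inputs, then classifies."""

    def __init__(self, d1, d2, d3, num_classes):
        super().__init__()
        self.fc = nn.Linear(d1 + d2 + d3, num_classes)

    def forward(self, x1, x2, x3):
        return self.fc(torch.cat([x1, x2, x3], dim=1))


def fit(model, data_loader, optimizer, phase="training", device="cpu"):
    """Run one epoch and return (mean loss, accuracy in percent)."""
    training = phase == "training"
    model.train(training)  # train(False) is equivalent to eval()
    running_loss, running_correct = 0.0, 0
    # Gradient tracking is disabled outside training (replaces volatile=True)
    with torch.set_grad_enabled(training):
        for data1, data2, data3, target in data_loader:
            data1, data2, data3, target = (
                t.to(device) for t in (data1, data2, data3, target)
            )
            if training:
                optimizer.zero_grad()
            output = model(data1, data2, data3)
            loss = F.cross_entropy(output, target)
            # Weight the batch-mean loss by batch size to get a per-sample mean later
            running_loss += loss.item() * target.size(0)
            running_correct += (output.argmax(dim=1) == target).sum().item()
            if training:
                loss.backward()
                optimizer.step()
    n = len(data_loader.dataset)
    return running_loss / n, 100.0 * running_correct / n


# Small demo on random data (shapes are arbitrary)
torch.manual_seed(0)
n = 32
ds = ThreeInputDataset(torch.randn(n, 4), torch.randn(n, 5),
                       torch.randn(n, 6), torch.randint(0, 3, (n,)))
loader = DataLoader(ds, batch_size=8, shuffle=True)
model = ThreeInputModel(4, 5, 6, num_classes=3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
train_loss, train_acc = fit(model, loader, opt, phase="training")
val_loss, val_acc = fit(model, loader, opt, phase="validation")
```

Using a single set_grad_enabled context keeps one code path for both phases. The training-only steps (zero_grad, backward, step) are still guarded by the phase flag, as in the original function.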