
154 |
第
5
章
train()
函数(将在本节后面展示)改编自
PyTorch
教程
Transfer Learning
for Computer Vision
注
6
中提供的示例。
我们首先对
model.state_dict()
中的所有预训练模型权重和变量
best_
model_weights
进行深度拷贝。我们将模型的最佳精度初始化为
0.0
。
然后,代码将迭代多个
epoch
。对于每个
epoch
,我们首先加载将用于训练的
数据和标签,然后将其移动到
GPU
。在调用
model
(
inputs
)执行正向传播
之前,我们重置优化器梯度。我们使用
CrossEntropyLoss
来计算损失。完成
这些步骤后,我们将调用
loss.backward()
进行反向传播。然后,我们使用
optimizer.step()
更新所有相关参数。
在
train()
函数中,你会注意到我们在验证阶段关闭了梯度计算。这是因为
在验证阶段,不需要梯度计算。你只需要使用验证输入来计算损失和精度。
我们检查当前
epoch
的验证精度是否比前一个
epoch
有所提高。如果验证
精度有所提高,我们将结果存储在
best_model_weights
中,并设置
best_
accuracy
以表示迄今为止观察到的最佳验证精度:
def train(model, criterion, optimizer, scheduler, num_epochs=10):
#
用于存储训练和验证损失
training_loss = []
val_loss = []
best_model_weights ...