Skip to Main Content
弱监督学习实用指南
book

弱监督学习实用指南

by Wee Hyong Tok, Amit Bahree, Senja Filipi
January 2023
Beginner to intermediate content levelBeginner to intermediate
209 pages
3h 55m
Chinese
Southeast University Press
Content preview from 弱监督学习实用指南
154
5
train()
函数(将在本节后面展示)改编自
PyTorch
教程
Transfer Learning
for Computer Vision
6
中提供的示例。
我们首先对
model.state_dict()
中的所有预训练模型权重和变量
best_
model_weights
进行深度拷贝。我们将模型的最佳精度初始化为
0.0
然后,代码将迭代多个
epoch
。对于每个
epoch
,我们首先加载将用于训练的
数据和标签,然后将其移动到
GPU
。在调用
model
inputs
执行正向传播
之前,我们重置优化器梯度。我们使用
CrossEntropyLoss
来计算损失。完成
这些步骤后,我们将调用
loss.backward()
进行反向传播。然后,我们使用
optimizer.step()
更新所有相关参数。
train()
函数中,你会注意到我们在验证阶段关闭了梯度计算。这是因为
在验证阶段,不需要梯度计算。你只需要使用验证输入来计算损失和精度。
我们检查当前
epoch
的验证精度是否比前一个
epoch
有所提高。如果验证
精度有所提高,我们将结果存储在
best_model_weights
中,并设置
best_
accuracy
以表示迄今为止观察到的最佳验证精度:
def train(model, criterion, optimizer, scheduler, num_epochs=10):
#
用于存储训练和验证损失
training_loss = []
val_loss = []
best_model_weights ...
Become an O’Reilly member and get unlimited access to this title plus top books and audiobooks from O’Reilly and nearly 200 top publishers, thousands of courses curated by job role, 150+ live events each month,
and much more.
Start your free trial

You might also like

人工智能技术与大数据

人工智能技术与大数据

Posts & Telecom Press, Anand Deshpande, Manish Kumar
解密金融数据

解密金融数据

Justin Pauley
C++语言导学(原书第2版)

C++语言导学(原书第2版)

本贾尼 斯特劳斯特鲁普

Publisher Resources

ISBN: 9787576602630