PyTorch Pocket Reference (PyTorch 口袋参考手册)

by Joe Papa
May 2025
Intermediate to advanced
310 pages
3h 16m
Chinese
O'Reilly Media, Inc.

Chapter 5. Customizing PyTorch

This work has been translated using artificial intelligence. Your feedback and comments are welcome: translation-feedback@oreilly.com

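Until now, you have been using built-in PyTorch classes, functions, and libraries to design and train various predefined models, model layers, and activation functions. But what if you have a novel idea, or you are conducting cutting-edge deep learning research? Perhaps you have invented an entirely new layer architecture or activation function. Maybe you have developed a new optimization algorithm or a special loss function that no one has seen before.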

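In this chapter, I'll show you how to create your own custom deep learning components and algorithms in PyTorch. We'll begin by exploring how to create custom layers and activation functions, and then see how to combine these components into custom model architectures. Next, I'll show you how to create your own loss functions and optimization algorithms. Finally, we'll look at how to create custom loops for training, validation, and testing.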

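PyTorch offers flexibility: you can extend the existing libraries, or you can combine your customizations into your own library or package. By creating custom components, you can solve new deep learning problems, speed up training, and discover innovative ways to perform deep learning.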

Let's get started by creating some custom deep learning layers and activation functions.

Custom Layers and Activations

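PyTorch offers an extensive set of built-in layers and activation functions. However, what makes PyTorch so popular, especially in the research community, is how easy it is to create custom layers and activation functions. This ability facilitates experimentation and accelerates research.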

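If we look at the PyTorch source code, we see that layers and activations are created using a functional definition and a class implementation. The functional definition specifies how the outputs are created from the inputs. It is defined in the nn.functional module. The class implementation is used to create an object that calls this function at its core, but it also includes additional features derived from the nn.Module class.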
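As a minimal sketch of this split, the built-in ReLU activation (chosen here only as an illustration) is available both as a function in nn.functional and as an nn.Module class, and the two forms should produce the same output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.0, 2.0])

# Functional definition: a plain function that maps input to output.
y_functional = F.relu(x)

# Class implementation: an nn.Module object that applies the same operation.
relu_layer = nn.ReLU()
y_module = relu_layer(x)

# Both forms compute the same values.
same = torch.equal(y_functional, y_module)
```

The class form is the one you would typically place inside a model, while the functional form is convenient inside a forward() method or in standalone code.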

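For example, let's look at how the fully connected nn.Linear layer is implemented. The following code shows a simplified version of the functional definition, nn.functional.linear():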

import torch

def linear(input, weight, bias=None):
    # weight has shape (out_features, in_features), so it is transposed below
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster: computes bias + input @ weight.t()
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret

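The linear() function multiplies the input tensor by the weight matrix, optionally adds the bias vector, and returns the result as a tensor. You can see that the code is optimized for performance. When the input has two dimensions and a bias is provided, the fused matrix-add function torch.addmm() is used because it is faster in this case.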
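The two branches of the function should give the same numerical result, differing only in speed. A small check under assumed shapes (a batch of 4 inputs, 3 input features, 2 output features; these sizes are arbitrary choices for illustration) might look like this:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)       # 2-D input: (batch, in_features)
weight = torch.randn(2, 3)  # (out_features, in_features)
bias = torch.randn(2)       # (out_features,)

# Fused path: bias + x @ weight.t() in a single call.
fused = torch.addmm(bias, x, weight.t())

# Unfused path: matrix multiply first, then add the bias.
unfused = x.matmul(weight.t()) + bias

# The two paths agree up to floating-point rounding.
paths_agree = torch.allclose(fused, unfused, atol=1e-5)
```

The comparison uses torch.allclose() rather than exact equality because the fused and unfused computations may round differently.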

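The benefit of keeping the mathematical computation in a separate functional definition is that the optimization is decoupled from the layer's nn.Module. The functional definition can also be used as a standalone function when writing code in general.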
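As a brief sketch of standalone use, the built-in torch.nn.functional.linear() can be called directly on tensors without constructing any layer object. The tensor values below are made up so that the result is easy to verify by hand:

```python
import torch
import torch.nn.functional as F

x = torch.ones(5, 3)              # 5 inputs with 3 features, all ones
weight = torch.full((2, 3), 0.5)  # (out_features, in_features), all 0.5
bias = torch.zeros(2)             # zero bias

# No nn.Module involved: the caller supplies the weight and bias tensors.
out = F.linear(x, weight, bias)

# Each output element is 3 * (1.0 * 0.5) + 0.0 = 1.5.
out_shape = tuple(out.shape)
```

Because the function holds no state of its own, the caller is responsible for creating and tracking the weight and bias tensors.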

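However, we usually subclass nn.Module to build our NNs. When we create an nn.Module subclass, we gain all the built-in benefits of the nn.Module object. In this case, the nn.Linear class is derived from nn.Module, as shown in the following code: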

import torch.nn as nn
from torch import Tensor

class Linear(nn.Module):

    def __init__(self, in_features ...
Publisher Resources

ISBN: 9798341658790