您的位置:首页 > 娱乐 > 明星 > 长沙网页设计_陈林设计事务所_网络广告推广服务_搭建一个网站

长沙网页设计_陈林设计事务所_网络广告推广服务_搭建一个网站

2025/1/15 15:08:55 来源:https://blog.csdn.net/shizheng_Li/article/details/144443672  浏览:    关键词:长沙网页设计_陈林设计事务所_网络广告推广服务_搭建一个网站
长沙网页设计_陈林设计事务所_网络广告推广服务_搭建一个网站

PyTorch 梯度计算详解:以 detach 示例为例

在深度学习中,梯度计算是训练模型的核心步骤,而 PyTorch 通过自动微分(autograd)模块实现了高效的梯度求解。本文将通过一个实际代码示例,详细讲解 PyTorch 的梯度计算过程,包括 backward() 函数的作用及工作原理、grad 属性的含义,以及如何分离计算图避免梯度传播。


示例代码

import torch# 定义张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y.detach()  # 分离 z,z 不会参与反向传播
w = z ** 2# 反向传播
w.sum().backward()# 打印梯度
print("x 的梯度:", x.grad)  
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

1. 梯度计算的核心概念

1.1 什么是梯度?

梯度是一个标量或张量的导数,表示函数的变化率。在机器学习中,梯度用来衡量损失函数相对于参数的变化方向和大小,以指导参数更新。

1.2 计算图

PyTorch 会在执行张量操作时构建一棵动态的计算图,每个张量都是一个节点,操作(如加法、乘法)是连接这些节点的边。计算图的作用是记录操作的顺序和依赖关系,从而实现反向传播。

1.3 requires_grad 属性
  • requires_grad=True 时,PyTorch 会为该张量记录梯度。
  • 默认情况下,requires_grad=False,意味着该张量不会参与梯度计算。

2. 示例中梯度的具体计算过程

2.1 代码解析
  1. 定义张量 x

    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    

    x 是一个可求梯度的张量,构成了计算图的起点。

  2. 定义计算 y

    y = x * 2
    

    此时,y 的计算被记录在计算图中,且与 x 有依赖关系。公式为:
    y = [ 2.0 , 4.0 , 6.0 ] y = [2.0, 4.0, 6.0] y=[2.0,4.0,6.0]

  3. 分离 z

    z = y.detach()
    

    使用 detach() 分离 z 后,z 不再记录计算图。虽然 z 的值等于 y,但它和 y 的计算历史已断开。

  4. 计算 w

    w = z ** 2
    

    w 的计算与 z 相关,因为 z 已被分离,它的计算不会影响原始计算图。

  5. 反向传播:

    w.sum().backward()
    

    由于 w 是通过分离后的 z 计算而来,反向传播不会更新 x 的梯度。


2.2 梯度的实际计算

让我们一步步分析梯度是如何通过计算图传播的。由于 z 被分离,原始计算图如下:

x (requires_grad=True) → y = x * 2
  • y 是通过 x * 2 得到的,因此其梯度可以表示为:
    ∂ y ∂ x = 2 \frac{\partial y}{\partial x} = 2 xy=2

但是,由于 z = y.detach() 分离了计算图,zx 没有任何依赖关系,因此 梯度不会计算到 x

最终 x.grad 输出为 None


3. backward()grad 的介绍

3.1 backward() 函数

backward() 是 PyTorch 用于计算梯度的核心函数。它从计算图的末端开始,沿着图的依赖关系,逐层向前计算每个张量的梯度。

使用方法:

loss.backward()
  • 工作原理:

    • 首先计算损失对每个变量的梯度。
    • 根据链式法则,逐层累积梯度。
    • 将计算结果存储在相关张量的 grad 属性中。
  • 参数说明:

    • retain_graph:是否保留计算图。默认为 False,计算完梯度后会释放计算图。
    • create_graph:是否创建计算图,允许对梯度再次求导。
3.2 grad 属性
  • grad 存储了张量的梯度,是反向传播的结果。
  • 只有 requires_grad=True 的张量才会有 grad 属性。
  • 如果张量未参与任何梯度计算,其 grad 属性为 None

4. 示例代码的修改与改进

为了让梯度正确传递,我们可以移除 detach() 操作:

改进代码:
import torch# 定义张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y  # 不使用 detach
w = z ** 2# 反向传播
w.sum().backward()# 打印梯度
print("x 的梯度:", x.grad)  # 输出:x 的梯度: tensor([4., 8., 12.])
计算过程:
  1. y = x * 2
    ∂ y ∂ x = 2 \frac{\partial y}{\partial x} = 2 xy=2

  2. w = z^2 = y^2
    ∂ w ∂ y = 2 y \frac{\partial w}{\partial y} = 2y yw=2y

  3. 利用链式法则:
    ∂ w ∂ x = ∂ w ∂ y ⋅ ∂ y ∂ x = 2 y ⋅ 2 = 4 x \frac{\partial w}{\partial x} = \frac{\partial w}{\partial y} \cdot \frac{\partial y}{\partial x} = 2y \cdot 2 = 4x xw=ywxy=2y2=4x

    最终,x.grad = [4.0, 8.0, 12.0]


5. 注意事项

  1. detach 和 no_grad 的区别

    • detach 是对单个张量操作,将其从计算图中分离。
    • torch.no_grad() 是一个上下文管理器,用于禁止其内所有计算图的创建,常用于推理阶段。
  2. 梯度累积
    默认情况下,PyTorch 会累积梯度(即多次调用 backward() 会叠加梯度)。如果不需要累积,可在每次计算前手动清零:

    optimizer.zero_grad()
    
  3. 链式法则与梯度传播
    PyTorch 的自动微分基于链式法则,因此每一步的梯度都会被准确传播。


总结

本文通过一个实际示例,详细解析了 PyTorch 中梯度的计算过程,重点介绍了计算图的概念以及 backward()grad 的使用。理解这些核心机制对于深度学习模型的开发与调试至关重要,希望这篇文章能帮助你更深入地掌握 PyTorch 的梯度计算。

后记

2024年12月13日10点24分于上海,在GPT4o大模型辅助下完成。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com