PyTorch 梯度计算详解:以 detach
示例为例
在深度学习中,梯度计算是训练模型的核心步骤,而 PyTorch 通过自动微分(autograd)模块实现了高效的梯度求解。本文将通过一个实际代码示例,详细讲解 PyTorch 的梯度计算过程,包括 backward()
函数的作用及工作原理、grad
属性的含义,以及如何分离计算图避免梯度传播。
示例代码
import torch# 定义张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y.detach() # 分离 z,z 不会参与反向传播
w = z ** 2# 反向传播
w.sum().backward()# 打印梯度
print("x 的梯度:", x.grad)
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
1. 梯度计算的核心概念
1.1 什么是梯度?
梯度是一个标量或张量的导数,表示函数的变化率。在机器学习中,梯度用来衡量损失函数相对于参数的变化方向和大小,以指导参数更新。
1.2 计算图
PyTorch 会在执行张量操作时构建一棵动态的计算图,每个张量都是一个节点,操作(如加法、乘法)是连接这些节点的边。计算图的作用是记录操作的顺序和依赖关系,从而实现反向传播。
1.3 requires_grad
属性
- 当
requires_grad=True
时,PyTorch 会为该张量记录梯度。 - 默认情况下,
requires_grad=False
,意味着该张量不会参与梯度计算。
2. 示例中梯度的具体计算过程
2.1 代码解析
-
定义张量
x
:x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x
是一个可求梯度的张量,构成了计算图的起点。 -
定义计算
y
:y = x * 2
此时,
y
的计算被记录在计算图中,且与x
有依赖关系。公式为:
y = [ 2.0 , 4.0 , 6.0 ] y = [2.0, 4.0, 6.0] y=[2.0,4.0,6.0] -
分离
z
:z = y.detach()
使用
detach()
分离z
后,z
不再记录计算图。虽然z
的值等于y
,但它和y
的计算历史已断开。 -
计算
w
:w = z ** 2
w
的计算与z
相关,因为z
已被分离,它的计算不会影响原始计算图。 -
反向传播:
w.sum().backward()
由于
w
是通过分离后的z
计算而来,反向传播不会更新x
的梯度。
2.2 梯度的实际计算
让我们一步步分析梯度是如何通过计算图传播的。由于 z
被分离,原始计算图如下:
x (requires_grad=True) → y = x * 2
y
是通过x * 2
得到的,因此其梯度可以表示为:
∂ y ∂ x = 2 \frac{\partial y}{\partial x} = 2 ∂x∂y=2
但是,由于 z = y.detach()
分离了计算图,z
和 x
没有任何依赖关系,因此 梯度不会计算到 x
。
最终 x.grad
输出为 None
。
3. backward()
和 grad
的介绍
3.1 backward()
函数
backward()
是 PyTorch 用于计算梯度的核心函数。它从计算图的末端开始,沿着图的依赖关系,逐层向前计算每个张量的梯度。
使用方法:
loss.backward()
-
工作原理:
- 首先计算损失对每个变量的梯度。
- 根据链式法则,逐层累积梯度。
- 将计算结果存储在相关张量的
grad
属性中。
-
参数说明:
retain_graph
:是否保留计算图。默认为False
,计算完梯度后会释放计算图。create_graph
:是否创建计算图,允许对梯度再次求导。
3.2 grad
属性
grad
存储了张量的梯度,是反向传播的结果。- 只有
requires_grad=True
的张量才会有grad
属性。 - 如果张量未参与任何梯度计算,其
grad
属性为None
。
4. 示例代码的修改与改进
为了让梯度正确传递,我们可以移除 detach()
操作:
改进代码:
import torch# 定义张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y # 不使用 detach
w = z ** 2# 反向传播
w.sum().backward()# 打印梯度
print("x 的梯度:", x.grad) # 输出:x 的梯度: tensor([4., 8., 12.])
计算过程:
-
y = x * 2
:
∂ y ∂ x = 2 \frac{\partial y}{\partial x} = 2 ∂x∂y=2 -
w = z^2 = y^2
:
∂ w ∂ y = 2 y \frac{\partial w}{\partial y} = 2y ∂y∂w=2y -
利用链式法则:
∂ w ∂ x = ∂ w ∂ y ⋅ ∂ y ∂ x = 2 y ⋅ 2 = 4 x \frac{\partial w}{\partial x} = \frac{\partial w}{\partial y} \cdot \frac{\partial y}{\partial x} = 2y \cdot 2 = 4x ∂x∂w=∂y∂w⋅∂x∂y=2y⋅2=4x最终,
x.grad = [4.0, 8.0, 12.0]
。
5. 注意事项
-
detach 和 no_grad 的区别
detach
是对单个张量操作,将其从计算图中分离。torch.no_grad()
是一个上下文管理器,用于禁止其内所有计算图的创建,常用于推理阶段。
-
梯度累积
默认情况下,PyTorch 会累积梯度(即多次调用backward()
会叠加梯度)。如果不需要累积,可在每次计算前手动清零:optimizer.zero_grad()
-
链式法则与梯度传播
PyTorch 的自动微分基于链式法则,因此每一步的梯度都会被准确传播。
总结
本文通过一个实际示例,详细解析了 PyTorch 中梯度的计算过程,重点介绍了计算图的概念以及 backward()
和 grad
的使用。理解这些核心机制对于深度学习模型的开发与调试至关重要,希望这篇文章能帮助你更深入地掌握 PyTorch 的梯度计算。
后记
2024年12月13日10点24分于上海,在GPT4o大模型辅助下完成。