序号 | 系列文章 |
---|---|
1 | 深度学习训练中GPU内存管理 |
2 | 深度学习PyTorch之数据加载DataLoader |
3 | 深度学习 PyTorch 中 18 种数据增强策略与实现 |
4 | 深度学习pytorch之简单方法自定义9类卷积即插即用 |
5 | 深度学习PyTorch之13种模型精度评估公式及调用方法 |
6 | 深度学习pytorch之4种归一化方法(Normalization)原理公式解析 |
7 | 深度学习pytorch之19种优化算法(optimizer)解析 |
8 | 深度学习pytorch之22种损失函数数学公式和代码定义 |
9 | DIY损失函数–以自适应边界损失为例 |
10 | 深度学习PyTorch之动态计算图可视化 - 使用 torchviz 生成计算图 |
文章目录
- 前言
- 1. 什么是动态计算图?
- 2. 为什么要可视化计算图?
- 3. 使用 `torchviz` 生成计算图
- 3.1 安装 `torchviz`
- 3.2 生成计算图完整代码示例
- 3.3 在训练过程中生成计算图
- 3.4 代码解读
- 3.5 生成的计算图
- 4. `torchviz` 的更多应用
- 5. 总结
- 参考文献
前言
在深度学习模型的开发过程中,理解和可视化模型的计算图对于调试、优化和教学都具有重要意义。PyTorch 采用的是动态图机制,这使得每次前向传播时计算图都被动态创建。而 torchviz
是一个非常有用的工具,它可以将这些动态图转化为可视化图形,帮助我们更直观地理解模型的计算过程。在本篇博客中,我们将重点介绍如何使用 torchviz
生成和保存 PyTorch 模型的计算图,并结合实际训练代码进行展示。
1. 什么是动态计算图?
在 PyTorch 中,计算图并不是在模型初始化时构建好的,而是通过前向传播过程动态地构建的。这种动态特性意味着每次运行时,计算图会根据输入数据的形状和大小而变化,因此我们可以灵活地进行调试和优化。PyTorch 的动态图提供了较高的灵活性,允许在计算图中嵌入复杂的控制流结构(例如循环和条件判断)。
2. 为什么要可视化计算图?
可视化计算图的优势在于:
- 调试:通过查看每一层的输入输出,可以快速发现模型设计上的问题。
- 优化:通过分析计算图,可以识别瓶颈和不必要的计算,进而优化模型性能。
- 教学:对于新手来说,计算图能够帮助他们理解深度学习模型的前向传播过程。
虽然 PyTorch 的动态图功能非常强大,但由于它不提供直接的计算图展示方式,因此我们需要借助外部工具 torchviz
进行可视化。
3. 使用 torchviz
生成计算图
torchviz
是一个能够将 PyTorch 计算图转化为图形的库,具体来说,它能够将计算图渲染为 DOT
格式并生成可视化图像文件(如 PNG 或 PDF)。我们通过以下几步可以生成计算图:
3.1 安装 torchviz
首先,你需要安装 torchviz
库。可以通过 pip
安装:
pip install torchviz
此时会直接将graphviz,torchziv两个都安装好,但是这种方法无法将graphviz导入系统路径。出现报错graphviz.backend.ExecutableNotFound: failed to execute ‘dot‘, make sure the Graphviz executables are***,需要从网址 Download | Graphviz下载graphviz的zip格式文件,解压后复制到以下python路径下即可。
3.2 生成计算图完整代码示例
核心语句只包括make_dot和render两个函数,其中:
- make_dot(y) 会根据输入张量 y 的计算过程生成计算图。
- render(“model_graph”, format=“png”) 将计算图保存为 PNG 图片。
import torch
import torch.nn as nn
import torch.optim as optim
from torchviz import make_dot# 定义一个简单的神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(2, 2)self.fc2 = nn.Linear(2, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型实例
model = SimpleNN()# 输入数据
x = torch.randn(1, 2)# 前向传播
y = model(x)# 可视化计算图
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("model_graph", format="png") # 保存图像为png文件
复制以上代码运行后生成model_graph.png如
3.3 在训练过程中生成计算图
假设你已经有了一个标准的 PyTorch 训练代码,并且希望在训练过程中生成计算图。我们可以在每次前向传播时使用 torchviz.make_dot
来生成计算图,并保存为 PNG 文件。
以下是一个集成计算图生成的训练代码示例:
import torch
from torchviz import make_dotfor epo in range(epo_num):print(epo)train_loss = 0train_acc = 0.0seg_model.train()for index, (img, label) in enumerate(train_dataloader):img = img.to(device)label = label.to(device)optimizer.zero_grad()output = seg_model(img) # 得到模型输出# 使用 torch.sigmoid 激活函数output = torch.sigmoid(output)# 生成计算图并保存为 PNG 文件if index == 0: # 只在第一个batch时生成计算图dot = make_dot(output, params=dict(seg_model.named_parameters()))dot.render("model_graph_epoch_{}_batch_{}".format(epo, index), format="png") # 保存为 epoch_x_batch_y.png# 计算损失loss = criterion(output, label)loss.backward()iter_loss = loss.item()all_train_iter_loss.append(iter_loss)train_loss += iter_lossoptimizer.step()# 计算准确率output_1 = output.argmax(dim=1)label_1 = label.argmax(dim=1)correct = torch.eq(output_1, label_1).sum().item()iter_acc = correct / label_1.numel()all_train_iter_acc.append(iter_acc)train_acc += iter_acc
3.4 代码解读
-
前向传播:
output = seg_model(img)
这一行执行了前向传播,计算了模型的输出。 -
计算图生成:在每个 epoch 的第一个 batch 中,使用
make_dot(output, params=dict(seg_model.named_parameters()))
来生成计算图。output
是模型的输出,而seg_model.named_parameters()
则提供了模型的参数信息,这对于生成完整的计算图非常有帮助。 -
保存计算图:通过
dot.render()
将计算图保存为 PNG 格式的文件。文件名包含当前的 epoch 和 batch 索引,以便于区分。dot.render("model_graph_epoch_{}_batch_{}".format(epo, index), format="png")
3.5 生成的计算图
计算图会包含模型中的每个操作(如矩阵乘法、加法等),以及这些操作之间的连接关系。通过计算图(以下示例),你可以清楚地看到模型的每一步计算如何进行。
4. torchviz
的更多应用
除了在训练过程中生成计算图,torchviz
还可以用于以下场景:
-
单步调试:如果你的模型非常复杂,可以在某个特定步骤(如单个前向传播)生成计算图,帮助调试。
-
模型设计:在设计新的网络架构时,通过生成计算图,可以确保每一层的输入输出形状是正确的。
-
计算性能分析:通过分析计算图中的每个节点,可以识别出性能瓶颈并进行优化。
5. 总结
PyTorch 的动态图特性使得每次前向传播时计算图都是动态生成的,而 torchviz
则提供了一个简便的工具,可以将这些动态生成的计算图可视化为图像文件。通过将 torchviz
集成到训练代码中,我们可以在训练过程中实时生成计算图,这不仅有助于我们调试模型,还可以为教学和研究提供更清晰的解释。
参考文献
- torchviz GitHub
- PyTorch 官方文档