torch.nn.ModuleList
是 PyTorch 中的一个容器类,用于存放多个子模块(nn.Module
),并将这些子模块注册为当前模块的属性。与 Python 原生的 list
类似,它可以存储一组模块,但是与普通的 Python 列表不同,ModuleList
中的模块会被自动识别为模型的一部分,因此可以参与参数的优化,并且在保存和加载模型时会自动处理这些子模块的状态。
主要功能
- 模块容器: 用于存储
nn.Module
子模块的列表。 - 自动注册参数:
ModuleList
中的模块会被自动添加到当前nn.Module
对象的参数中,并在调用model.parameters()
时包含这些模块的参数。 - 支持灵活索引: 可以像操作普通 Python 列表一样,通过索引访问和迭代这些模块。
使用场景
ModuleList
通常用于存放一系列相似的模块,适用于网络的不同层,例如在需要多层 nn.Linear
、nn.Conv2d
或者循环神经网络的情况下。
主要方法
append(module)
: 向列表中添加模块。extend(modules)
: 使用可迭代对象扩展ModuleList
。- 索引访问: 可以通过下标来访问存储的模块。
示例
1. 简单使用
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 使用 ModuleList 存放多个 Linear 层self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])def forward(self, x):# 对每个 Linear 层进行前向传播for layer in self.linears:x = layer(x)return xmodel = MyModel()
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
在上面的例子中,ModuleList
存放了 5 个 nn.Linear
层,它们会参与模型参数的更新和优化。
2. 动态添加模块
import torch
import torch.nn as nnclass MyDynamicModel(nn.Module):def __init__(self):super(MyDynamicModel, self).__init__()self.layers = nn.ModuleList()def add_layer(self, layer):# 动态添加层self.layers.append(layer)def forward(self, x):for layer in self.layers:x = layer(x)return xmodel = MyDynamicModel()
# 动态添加层
model.add_layer(nn.Linear(10, 20))
model.add_layer(nn.ReLU())
model.add_layer(nn.Linear(20, 5))input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
在这个例子中,我们可以在运行时动态添加层到 ModuleList
中。
与 Sequential
的区别
ModuleList
: 仅仅是一个模块的列表,存放了子模块,但不会定义前向传播过程。你需要自己手动调用每个模块。Sequential
: 也是一个模块容器,但它会定义好前向传播逻辑,将所有子模块按顺序连接起来,前向传播时会依次调用这些模块。
总结
nn.ModuleList
用于灵活地存储一组nn.Module
子模块,并且会将这些子模块的参数自动注册到模型中,方便模型训练和保存。- 它适合需要灵活管理或动态构建神经网络层的场景。