5.2 参数管理
每个网络都由各层组成,一个网络模块中的层可由索引访问
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(net[2])
输出:
Linear(in_features=8, out_features=1, bias=True)
5.2.1 参数访问
网络中的参数一般是指各层权重和偏置
若想访问某层的参数,用层来调用state_dict()函数
print(net[2].state_dict())
输出:
OrderedDict([('weight', tensor([[-0.3343, 0.3289, -0.0063, 0.0594, -0.1051, -0.3419, 0.2796, 0.0557]])), ('bias', tensor([0.3026]))])
可直接对各层参数进行调用
由于需要目标函数对参数求梯度进行优化,所以需要记录梯度
所以各层的参数也具有属性grad,梯度初始化为None
net[2].weight.grad == None
5.2.1.2 一次性访问所有参数
使用named_parameters()访问所有参数
各层或者整个模型都可以调用
print(*[(name, param.shape) for name, param in net[0].named_parameters()]) print([(name, param.shape) for name, param in net.named_parameters()])
输出:
('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))
[('0.weight', torch.Size([8, 4])), ('0.bias', torch.Size([8])), ('2.weight', torch.Size([1, 8])), ('2.bias', torch.Size([1]))]
注:
print
函数的 *
操作符用于将列表中的每个元组作为独立的参数传递给 print
,这样 print
函数会直接打印列表中的元组