在PyTorch中,如果不对网络参数进行显式初始化,各层会使用其默认的初始化方法。不同层类型的初始化策略有所不同,以下是常见层的默认初始化方式:
1. 全连接层 (nn.Linear
)
• 权重初始化:使用Kaiming均匀分布(He初始化),假设激活函数为Leaky ReLU(负斜率a=sqrt(5)
)。初始化范围根据输入维度(fan_in
)计算,公式为:
bound = 1 fan_in \text{bound} = \frac{1}{\sqrt{\text{fan\_in}}} bound=fan_in1
权重从均匀分布 ( U(-\text{bound}, \text{bound}) ) 中采样。
• 偏置初始化:均匀分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。
2. 卷积层 (nn.Conv2d
, nn.Conv1d
, nn.Conv3d
)
• 权重初始化:与全连接层类似,使用Kaiming均匀分布,但fan_in
计算为输入通道数乘以卷积核面积(例如,对于Conv2d
,fan_in = in_channels * kernel_height * kernel_width
)。
• 偏置初始化:与全连接层相同,均匀分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。
3. LSTM/GRU层 (nn.LSTM
, nn.GRU
)
• 权重初始化:权重从均匀分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
• 偏置初始化:偏置分为两部分,一部分初始化为零,另一部分从均匀分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
4. 批归一化层 (nn.BatchNorm1d
, nn.BatchNorm2d
)
• 缩放参数(weight):初始化为1。
• 偏移参数(bias):初始化为0。
5. 嵌入层 (nn.Embedding
)
• 权重初始化:从正态分布 ( N(0, 1) ) 中采样。
默认初始化的潜在问题
• 激活函数不匹配:Kaiming初始化默认假设使用Leaky ReLU(a=sqrt(5)
),若使用ReLU或其他激活函数,可能需要手动调整初始化方式以避免梯度不稳定。
• 深层网络训练:默认初始化在较浅网络中表现良好,但在深层网络中可能需要更精细的初始化(如Xavier或正交初始化)。
代码示例:查看默认初始化
import torch.nn as nn# 定义层
linear = nn.Linear(100, 50)
conv = nn.Conv2d(3, 16, kernel_size=3)
lstm = nn.LSTM(input_size=10, hidden_size=20)# 打印权重范围和标准差
def print_init_info(module):for name, param in module.named_parameters():if 'weight' in name:print(f"{name} mean: {param.data.mean():.4f}, std: {param.data.std():.4f}, range: [{param.data.min():.4f}, {param.data.max():.4f}]")elif 'bias' in name:print(f"{name} mean: {param.data.mean():.4f}")print("Linear层初始化信息:")
print_init_info(linear)print("\nConv2d层初始化信息:")
print_init_info(conv)print("\nLSTM层初始化信息:")
print_init_info(lstm)
手动初始化推荐
若默认初始化不适用,可手动初始化以适配激活函数:
# 针对ReLU的Kaiming初始化
for module in model.modules():if isinstance(module, (nn.Linear, nn.Conv2d)):nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')if module.bias is not None:nn.init.zeros_(module.bias)elif isinstance(module, nn.LSTM):for name, param in module.named_parameters():if 'weight' in name:nn.init.xavier_uniform_(param)elif 'bias' in name:nn.init.zeros_(param)