U-MambaEnc-2d.py
# 导入必要的模块
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个上采样层类,继承自 nn.Module
class UpsampleLayer(nn.Module):# 初始化方法,定义层的结构和所需的超参数def __init__(self, # 初始化函数的参数列表conv_op, # 卷积操作,用于在上采样后处理特征图input_channels, # 输入通道数output_channels, # 输出通道数pool_op_kernel_size, # 上采样的倍数或核大小mode='nearest' # 上采样使用的插值方法,默认使用 nearest 插值方法):# 调用父类的初始化方法super().__init__()# 定义一个卷积层,卷积核大小为1x1,输入通道为 input_channels,输出通道为 output_channelsself.conv = conv_op(input_channels, output_channels, kernel_size=1)# 存储上采样的核大小(倍数)self.pool_op_kernel_size = pool_op_kernel_size# 存储上采样模式self.mode = mode# 前向传播方法def forward(self, x):# 对输入张量 x 进行上采样操作,使用 F.interpolate 来进行插值# 上采样比例由 pool_op_kernel_size 决定,mode 决定使用的插值方式(默认为 'nearest')x = F.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode)# 对上采样后的结果应用卷积操作,使用的是初始化时定义的卷积层 self.convx = self.conv(x)# 返回处理后的特征图return x
这段代码定义了一个自定义的
UpsampleLayer
类,结合了上采样(通常是通过插值)和卷积操作。它可以用于在神经网络中实现以下功能:
上采样:使用
F.interpolate
对输入特征图进行上采样操作。scale_factor
参数确定了上采样的倍数,mode
参数指定了上采样时使用的插值方法(如nearest
,即最近邻插值,或者其他如bilinear
等方式)。卷积:上采样之后,通常会应用一个卷积层来进一步处理特征图。这段代码中的卷积层使用了大小为
1x1
的卷积核,作用是改变通道数(从输入的input_channels
转换为output_channels
),同时保持空间尺寸不变。
MambaLayer
# 导入必要的模块
import torch
import torch.nn as nn
from torch.cuda.amp import autocast# 定义一个自定义的 MambaLayer 类,继承自 nn.Module
class MambaLayer(nn.Module):# 初始化方法,定义层的结构和超参数def __init__(self, dim, d_state=16, d_conv=4, expand=2, channel_token=False):super().__init__()# 打印当前的维度参数 dimprint(f"MambaLayer: dim: {dim}")# 保存输入的维度self.dim = dim# 定义 LayerNorm 层,用于对输入进行归一化self.norm = nn.LayerNorm(dim)# 初始化 Mamba 模块,传入各个参数self.mamba = Mamba(d_model=dim, # 模型的维度d_state=d_state, # SSM 状态扩展因子d_conv=d_conv, # 局部卷积的宽度expand=expand, # 模块扩展因子)# 是否使用通道作为 token(通道维度是否作为一个额外的标识符)self.channel_token = channel_token# 通过图像块 token 方式进行前向传播def forward_patch_token(self, x):B, d_model = x.shape[:2] # B: batch size, d_model: 模型维度(通常是输入特征的通道数)# 确保输入的 d_model 与该层的 dim 参数匹配assert d_model == self.dim# 获取输入的空间维度(图像的高度和宽度)n_tokens = x.shape[2:].numel() # 计算图像块的数量(空间维度的元素个数)img_dims = x.shape[2:] # 获取图像的空间尺寸# 将输入张量展开为二维形式,(batch_size, d_model, n_tokens)x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2) # 转置,使得 tokens 在最后一维# 对展开后的张量进行归一化x_norm = self.norm(x_flat)# 使用 Mamba 模块处理归一化后的张量x_mamba = self.mamba(x_norm)# 恢复张量形状,转置回原始的空间维度并恢复为三维out = x_mamba.transpose(-1, -2).reshape(B, d_model, *img_dims)# 返回处理后的结果return out# 通过通道 token 方式进行前向传播def forward_channel_token(self, x):B, n_tokens = x.shape[:2] # B: batch size, n_tokens: 这里指的是通道的数量# 计算展平后的 d_model(即通道数乘以空间维度数)d_model = x.shape[2:].numel()# 确保展平后的 d_model 和该层的 dim 参数一致assert d_model == self.dim, f"d_model: {d_model}, self.dim: {self.dim}"# 获取图像的空间维度img_dims = x.shape[2:]# 将输入张量展平为三维形状 (B, n_tokens, d_model)x_flat = x.flatten(2) # 展平后,维度变为 (B, n_tokens, d_model)# 确保展平后的张量维度与 d_model 匹配assert x_flat.shape[2] == d_model, f"x_flat.shape[2]: {x_flat.shape[2]}, d_model: {d_model}"# 对展平后的张量进行归一化x_norm = self.norm(x_flat)# 使用 Mamba 模块处理归一化后的张量x_mamba = self.mamba(x_norm)# 恢复张量的形状,恢复为 B, n_tokens 和图像的空间维度out = x_mamba.reshape(B, n_tokens, *img_dims)# 返回处理后的结果return out# 前向传播方法,根据是否使用通道 token 来选择处理方式@autocast(enabled=False) # 关闭自动混合精度def forward(self, x):# 如果输入是 float16 类型,则将其转换为 float32 类型if x.dtype == torch.float16:x = x.type(torch.float32)# 根据 channel_token 参数的值选择使用哪种方式if self.channel_token:# 使用通道作为 token 进行前向传播out = self.forward_channel_token(x)else:# 使用图像块作为 token 进行前向传播out = self.forward_patch_token(x)# 返回输出结果return out
MambaLayer
是一个自定义的神经网络层,主要通过两种方式进行前向传播:通过图像块 token 或 通过通道 token。这个层使用了Mamba
模块和LayerNorm
层来进行数据处理。具体来说:
输入结构:
x
是一个四维张量,表示一批图像的特征图,形状通常为(batch_size, channels, height, width)
。
forward_patch_token
:
- 将输入的特征图展开成一个长向量,按图像块(patch token)处理。
- 使用
Mamba
模块处理归一化后的输入数据,并恢复为图像的空间尺寸。
forward_channel_token
:
- 使用通道维度作为 token 进行处理。将输入的通道信息作为 token,展开并进行归一化处理,然后通过
Mamba
处理。
autocast(enabled=False)
:
- 通过
autocast
装饰器控制自动混合精度,虽然此处禁用了混合精度(enabled=False
)。动态选择输入模式:
- 根据
channel_token
的布尔值选择使用哪种输入方式。如果channel_token
为True
,则使用forward_channel_token
;否则使用forward_patch_token
。
import torch
import torch.nn as nn# 定义一个基本的残差块 (Basic ResBlock),继承自 nn.Module
class BasicResBlock(nn.Module):# 初始化方法,定义该残差块的各个参数def __init__(self,conv_op, # 卷积操作(例如,nn.Conv2d)input_channels, # 输入通道数output_channels, # 输出通道数norm_op, # 正则化操作(例如,nn.BatchNorm2d)norm_op_kwargs, # 正则化操作的额外参数(如 momentum, eps 等)kernel_size=3, # 卷积核的大小,默认为 3x3padding=1, # 卷积的填充大小,默认为 1stride=1, # 卷积的步幅,默认为 1use_1x1conv=False, # 是否使用 1x1 卷积nonlin=nn.LeakyReLU, # 非线性激活函数,默认为 LeakyReLUnonlin_kwargs={'inplace': True} # 激活函数的额外参数,默认为 inplace=True):super().__init__() # 调用父类的初始化方法# 定义第一个卷积层,输入通道数为 input_channels,输出通道数为 output_channels# 卷积核大小为 kernel_size,步幅为 stride,填充为 paddingself.conv1 = conv_op(input_channels, output_channels, kernel_size, stride=stride, padding=padding)# 定义第一个正则化层,应用于 conv1 的输出,采用 norm_op 作为正则化操作self.norm1 = norm_op(output_channels, **norm_op_kwargs)# 定义第一个激活层,使用非线性激活函数 (默认为 LeakyReLU),并传递额外参数self.act1 = nonlin(**nonlin_kwargs)# 定义第二个卷积层,输入和输出通道数均为 output_channels,卷积核大小为 kernel_size,填充为 paddingself.conv2 = conv_op(output_channels, output_channels, kernel_size, padding=padding)# 定义第二个正则化层,应用于 conv2 的输出self.norm2 = norm_op(output_channels, **norm_op_kwargs)# 定义第二个激活层,使用非线性激活函数 (默认为 LeakyReLU),并传递额外参数self.act2 = nonlin(**nonlin_kwargs)# 如果需要使用 1x1 卷积,定义 conv3 层作为一个额外的 1x1 卷积,用于调整输入的通道数if use_1x1conv:self.conv3 = conv_op(input_channels, output_channels, kernel_size=1, stride=stride)else:self.conv3 = None # 如果不需要 1x1 卷积,则设置为 None# 前向传播方法def forward(self, x):# 通过第一个卷积层、归一化层和激活函数处理输入数据y = self.conv1(x)y = self.act1(self.norm1(y)) # y 经过第一个卷积层 -> norm -> 激活函数# 通过第二个卷积层和归一化层处理数据y = self.norm2(self.conv2(y)) # y 经过第二个卷积层 -> norm# 如果需要使用 1x1 卷积,则调整输入 x 的通道数if self.conv3:x = self.conv3(x) # 如果 conv3 存在,则调整输入的通道数# 残差连接:将输入 x 与卷积和激活后的输出相加y += x # 残差连接,y 和 x 相加# 最后通过第二个激活函数进行处理并返回结果return self.act2(y) # 输出经过第二个激活函数后的结果
这段代码定义了一个 基本残差块(
BasicResBlock
),它是构建深度神经网络时常用的模块之一,尤其在 ResNet(残差网络)中广泛使用。主要步骤:
- 卷积层 (
conv_op
):用于提取特征,conv1
和conv2
分别是前向传播的两层卷积。conv1
会对输入进行特征提取,conv2
则进一步处理这些特征。- 正则化层 (
norm_op
):使用批量归一化(nn.BatchNorm2d
)或其他正则化方法,帮助加速训练并提高模型稳定性。- 激活函数 (
nonlin
):通常用于引入非线性因素,LeakyReLU
用于避免 ReLU 的死神经元问题。- 1x1 卷积:可选地,通过使用
conv3
(1x1 卷积)来调整输入的通道数。这对于输入和输出的通道数不匹配的情况很有用。use_1x1conv
参数控制是否使用该层。- 残差连接:这是核心思想之一。在
forward
函数中,输入x
直接加到处理后的输出y
上,形成“跳跃连接”或者残差连接。这有助于缓解深度网络中梯度消失和梯度爆炸的问题。- 返回输出:最终,输出会通过第二个激活函数
act2
,以保证残差连接之后的特征在非线性空间中得到进一步处理。
ResidualMambaEncoder
class ResidualMambaEncoder(nn.Module):# 初始化方法,设置网络的各项参数def __init__(self,input_size: Tuple[int, ...], # 输入数据的尺寸 (如图像的宽高)input_channels: int, # 输入通道数(例如 RGB 图像为 3)n_stages: int, # 网络的阶段数,表示网络的深度或者不同分辨率的阶段features_per_stage: Union[int, List[int], Tuple[int, ...]], # 每个阶段的特征数conv_op: Type[_ConvNd], # 卷积操作类型,通常是 nn.Conv2dkernel_sizes: Union[int, List[int], Tuple[int, ...]], # 每个阶段的卷积核大小strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], # 每个阶段的卷积步幅n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], # 每个阶段的残差块数conv_bias: bool = False, # 是否使用卷积的偏置项,默认为 Falsenorm_op: Union[None, Type[nn.Module]] = None, # 正则化操作,默认为 Nonenorm_op_kwargs: dict = None, # 正则化操作的参数,默认为 Nonenonlin: Union[None, Type[torch.nn.Module]] = None, # 激活函数,默认为 Nonenonlin_kwargs: dict = None, # 激活函数的额外参数,默认为 Nonereturn_skips: bool = False, # 是否返回跳跃连接,默认为 Falsestem_channels: int = None, # 起始阶段的通道数,默认为 Nonepool_type: str = 'conv', # 池化操作类型,默认为 'conv'):super().__init__() # 调用父类 nn.Module 的初始化方法# 如果 kernel_sizes 是一个整数,扩展成包含 n_stages 个相同值的列表if isinstance(kernel_sizes, int):kernel_sizes = [kernel_sizes] * n_stages# 如果 features_per_stage 是一个整数,扩展成包含 n_stages 个相同值的列表if isinstance(features_per_stage, int):features_per_stage = [features_per_stage] * n_stages# 如果 n_blocks_per_stage 是一个整数,扩展成包含 n_stages 个相同值的列表if isinstance(n_blocks_per_stage, int):n_blocks_per_stage = [n_blocks_per_stage] * n_stages# 如果 strides 是一个整数,扩展成包含 n_stages 个相同值的列表if isinstance(strides, int):strides = [strides] * n_stages# 校验输入的各个列表的长度是否与 n_stages 一致assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"assert len(n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \"Important: first entry is recommended to be 1, else we run strided conv drectly on the input"# 获取池化操作函数(如果 pool_type 不等于 'conv')pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None# 初始化一个空的列表,用于存储每个阶段是否需要进行通道标记(channel token)do_channel_token = [False] * n_stages# 初始化一个空列表,用于存储每个阶段的特征图大小feature_map_sizes = []feature_map_size = input_size # 初始化特征图大小为输入图像的尺寸# 遍历每个阶段,计算每个阶段的特征图尺寸for s in range(n_stages):# 根据当前阶段的步幅 (strides[s]) 计算当前阶段特征图的尺寸feature_map_sizes.append([i // j for i, j in zip(feature_map_size, strides[s])])feature_map_size = feature_map_sizes[-1] # 更新当前的特征图尺寸# 如果当前阶段的特征图元素数目小于或等于该阶段的特征数量,则标记该阶段需要通道标记(channel token)if np.prod(feature_map_size) <= features_per_stage[s]:do_channel_token[s] = True# 输出每个阶段的特征图大小print(f"feature_map_sizes: {feature_map_sizes}")# 输出每个阶段是否需要通道标记print(f"do_channel_token: {do_channel_token}")
# 初始化一个空列表,用于存储卷积操作中每个卷积核大小对应的填充大小
self.conv_pad_sizes = []
# 遍历每个卷积核大小(kernel_sizes),计算每个卷积核的填充大小
for krnl in kernel_sizes:# 假设卷积核大小为奇数,填充为卷积核大小的一半# 这里使用整数除法来确定填充的大小,通常卷积核大小为奇数时会对称地填充self.conv_pad_sizes.append([i // 2 for i in krnl])# 获取第一个阶段的通道数,用于定义 stem 部分的输出通道数
stem_channels = features_per_stage[0]# 定义网络的 stem 部分(即最开始的特征提取层)
self.stem = nn.Sequential(# 定义第一个 BasicResBlock,作为网络的起始卷积块BasicResBlock(conv_op=conv_op, # 卷积操作类型input_channels=input_channels, # 输入通道数output_channels=stem_channels, # 输出通道数norm_op=norm_op, # 正则化操作norm_op_kwargs=norm_op_kwargs, # 正则化操作的参数kernel_size=kernel_sizes[0], # 卷积核大小padding=self.conv_pad_sizes[0], # 填充大小stride=1, # 步幅,设为 1nonlin=nonlin, # 激活函数nonlin_kwargs=nonlin_kwargs, # 激活函数的额外参数use_1x1conv=True # 是否使用 1x1 卷积),# 使用 BasicBlockD 定义更多的卷积块,数量为 n_blocks_per_stage[0] - 1*[BasicBlockD(conv_op=conv_op, # 卷积操作类型input_channels=stem_channels, # 输入通道数(stem 部分的输出通道)output_channels=stem_channels, # 输出通道数与输入通道数相同kernel_size=kernel_sizes[0], # 卷积核大小stride=1, # 步幅设为 1conv_bias=conv_bias, # 是否使用卷积偏置norm_op=norm_op, # 正则化操作norm_op_kwargs=norm_op_kwargs, # 正则化操作的参数nonlin=nonlin, # 激活函数nonlin_kwargs=nonlin_kwargs # 激活函数的额外参数) for _ in range(n_blocks_per_stage[0] - 1) # 循环生成多个卷积块]
)# 更新输入通道数为 stem 部分的输出通道数
input_channels = stem_channels# 定义一个空的列表,用于存储后续各个阶段(stages)
stages = []
mamba_layers = []# 遍历每个阶段(n_stages),为每个阶段定义卷积块
for s in range(n_stages):# 定义每个阶段的卷积块,首先是一个 BasicResBlockstage = nn.Sequential(BasicResBlock(conv_op=conv_op, # 卷积操作类型norm_op=norm_op, # 正则化操作norm_op_kwargs=norm_op_kwargs, # 正则化操作的参数input_channels=input_channels, # 输入通道数output_channels=features_per_stage[s], # 输出通道数,根据阶段的配置来设定kernel_size=kernel_sizes[s], # 卷积核大小padding=self.conv_pad_sizes[s], # 填充大小stride=strides[s], # 步幅use_1x1conv=True, # 是否使用 1x1 卷积nonlin=nonlin, # 激活函数nonlin_kwargs=nonlin_kwargs # 激活函数的额外参数),# 使用 BasicBlockD 定义更多的卷积块,数量为 n_blocks_per_stage[s] - 1*[BasicBlockD(conv_op=conv_op, # 卷积操作类型input_channels=features_per_stage[s], # 输入通道数(本阶段的输出通道数)output_channels=features_per_stage[s], # 输出通道数与输入通道数相同kernel_size=kernel_sizes[s], # 卷积核大小stride=1, # 步幅设为 1conv_bias=conv_bias, # 是否使用卷积偏置norm_op=norm_op, # 正则化操作norm_op_kwargs=norm_op_kwargs, # 正则化操作的参数nonlin=nonlin, # 激活函数nonlin_kwargs=nonlin_kwargs # 激活函数的额外参数) for _ in range(n_blocks_per_stage[s] - 1) # 循环生成多个卷积块])# 将当前阶段(stage)添加到 stages 列表中stages.append(stage)# 目前的阶段输出通道数作为下一阶段的输入通道数input_channels = features_per_stage[s]
# 判断当前阶段的索引 (s) 和网络总阶段数 (n_stages) 是否满足特定条件,
# 保证最后一个阶段有一个 MambaLayer。
if bool(s % 2) ^ bool(n_stages % 2):# 如果满足条件,则将 MambaLayer 添加到 mamba_layers 列表中。# 其中 dim 是特征图大小的乘积 (在 do_channel_token[s] 为 True 时,dim 会取决于 feature_map_sizes[s],# 否则取 features_per_stage[s]),channel_token 决定是否使用通道 token。mamba_layers.append(MambaLayer(dim = np.prod(feature_map_sizes[s]) if do_channel_token[s] else features_per_stage[s],channel_token = do_channel_token[s]))
else:# 否则,向 mamba_layers 列表添加一个 Identity 层。# Identity 层是一个恒等映射,即不改变输入数据。mamba_layers.append(nn.Identity())# 将当前阶段(stage)添加到 stages 列表中。
stages.append(stage)# 更新输入通道数为当前阶段的输出通道数,供下一个阶段使用。
input_channels = features_per_stage[s]# 将 mamba_layers 列表转化为一个 ModuleList,方便管理和参数更新。
self.mamba_layers = nn.ModuleList(mamba_layers)# 将 stages 列表转化为一个 ModuleList,方便管理和参数更新。
self.stages = nn.ModuleList(stages)# 设置网络的输出通道数为每个阶段的输出通道数(features_per_stage)。
self.output_channels = features_per_stage# 将步幅(strides)转换为列表形式,确保每个卷积层都使用正确的步幅。
self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]# 设置是否返回跳跃连接(skip connections)的标志。
self.return_skips = return_skips# 存储卷积操作类型、正则化操作类型、正则化参数、激活函数等配置。
self.conv_op = conv_op
self.norm_op = norm_op
self.norm_op_kwargs = norm_op_kwargs
self.nonlin = nonlin
self.nonlin_kwargs = nonlin_kwargs# 存储是否使用卷积偏置(conv_bias)和卷积核大小(kernel_sizes)。
self.conv_bias = conv_bias
self.kernel_sizes = kernel_sizes# 这是网络的初始化部分,主要是设置网络的各个模块、操作类型和超参数等。
def forward(self, x):# 如果网络有 stem(初始卷积块),则对输入 x 进行处理。if self.stem is not None:x = self.stem(x)# 初始化一个列表,用来存储每个阶段的输出。ret = []# 遍历每个阶段,执行前向传播。for s in range(len(self.stages)):# 对输入 x 使用当前阶段的卷积操作。x = self.stages[s](x)# 使用当前阶段的 MambaLayer 进行处理。如果是 Identity 层,则不会改变输入 x。x = self.mamba_layers[s](x)# 将当前阶段的输出添加到 ret 列表中。ret.append(x)# 如果需要返回跳跃连接(skip connections),则返回 ret 列表中的所有结果。if self.return_skips:return retelse:# 否则,仅返回列表中的最后一个结果。return ret[-1]
ResidualMambaEncoder
类的作用是构建一个具有多个阶段的深度神经网络模型,其中每个阶段都包含卷积操作、正则化层、激活函数和一个定制的变换层MambaLayer
,而在某些条件下,还可以使用恒等映射层Identity
来避免任何变化。这个类的设计目标可能是用于图像特征提取或其他需要多阶段处理的任务。我们可以从代码和注释中推测出
ResidualMambaEncoder
的整体功能和设计理念。类的作用总结:
多阶段特征提取:
ResidualMambaEncoder
通过多个阶段(stages
)对输入数据进行逐步处理,每个阶段包含卷积、正则化和激活函数等层。每个阶段会学习不同层次的特征,类似于传统的卷积神经网络(CNN)架构中的不同层。MambaLayer 的应用:
MambaLayer
是一个特定的变换层,可以根据不同的条件对输入特征图进行处理。在某些阶段(依据s % 2
和n_stages % 2
的条件),MambaLayer
会被添加到网络中,用于执行额外的操作,例如通道编码或某种特征变换。否则,使用Identity
层,不对输入做任何更改。跳跃连接(Skip Connections):
ResidualMambaEncoder
支持跳跃连接,这意味着在每个阶段的输出都可以被保留并返回,以便后续使用。这在残差网络中非常常见,有助于减缓训练过程中梯度消失的问题。卷积层的构建:
- 通过
conv_op
和其他相关参数,网络的每个阶段都会包含卷积操作。每个卷积操作的步幅(stride)、卷积核大小(kernel size)等配置都是可调的,提供了灵活性。输入输出通道管理:
- 在每个阶段,输入通道和输出通道都会根据
features_per_stage
来调整,这确保了每一阶段都有正确的通道数进行特征处理。
compute_conv_feature_map_size
def compute_conv_feature_map_size(self, input_size):# 检查是否存在stem部分(通常是网络的初始卷积层)if self.stem is not None:# 如果存在stem,调用stem部分的compute_conv_feature_map_size方法计算经过stem处理后的特征图尺寸output = self.stem.compute_conv_feature_map_size(input_size)else:# 如果stem不存在,初始化output为0,表示没有初始的卷积层output = np.int64(0)# 遍历网络中的每个阶段(stage),每个阶段通常包含卷积、激活等操作for s in range(len(self.stages)):# 调用当前阶段的compute_conv_feature_map_size方法,计算该阶段后的特征图尺寸,并累加到output中output += self.stages[s].compute_conv_feature_map_size(input_size)# 更新input_size,模拟经过当前阶段后输入尺寸的变化# 通过将当前输入尺寸(input_size)与该阶段的步幅(self.strides[s])对应元素相除,来更新输入尺寸# 这表示步幅对特征图尺寸的影响,即卷积操作如何缩小特征图input_size = [i // j for i, j in zip(input_size, self.strides[s])]# 返回所有阶段累加后的特征图总尺寸return output
这段代码的作用是计算通过网络(包括
stem
和多个阶段stages
)后,输出特征图的大小。它考虑了每个阶段的卷积操作对特征图大小的影响。具体地:
- 首先计算
stem
部分(如果存在)的输出特征图大小。- 然后依次计算每个阶段(
stages
)的输出特征图大小,并根据每个阶段的步幅调整输入尺寸。- 最终返回所有阶段累加后的特征图总大小。
UNetResDecoder
class UNetResDecoder(nn.Module):# 初始化解码器def __init__(self,encoder, # 编码器,用于提供中间特征(skip connections)num_classes, # 分割任务的类别数量n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], # 每个解码阶段使用的卷积层数量deep_supervision, # 是否使用深度监督nonlin_first: bool = False): # 是否先使用非线性激活super().__init__()self.deep_supervision = deep_supervision # 是否使用深度监督self.encoder = encoder # 传入的编码器self.num_classes = num_classes # 目标类别数n_stages_encoder = len(encoder.output_channels) # 编码器阶段数,根据编码器输出的通道数来判断# 如果 n_conv_per_stage 是整数,表示每个阶段的卷积层数都一样if isinstance(n_conv_per_stage, int):n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)# 确保卷积层数数组的长度与编码器的阶段数一致assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \"resolution stages - 1 (n_stages in encoder - 1), " \"here: %d" % n_stages_encoder# 初始化阶段、上采样层和分割层列表stages = []upsample_layers = []seg_layers = []# 遍历每个解码阶段(除去最后一个阶段),构建上采样和卷积层for s in range(1, n_stages_encoder):input_features_below = encoder.output_channels[-s] # 编码器当前阶段的输入通道数input_features_skip = encoder.output_channels[-(s + 1)] # 编码器跳跃连接的输入通道数stride_for_upsampling = encoder.strides[-s] # 上采样的步幅# 创建上采样层(通过最近邻插值进行上采样)upsample_layers.append(UpsampleLayer(conv_op=encoder.conv_op, # 使用编码器的卷积操作input_channels=input_features_below, # 输入通道数output_channels=input_features_skip, # 输出通道数pool_op_kernel_size=stride_for_upsampling, # 上采样的步幅mode='nearest' # 上采样方法为最近邻))# 创建当前解码阶段的卷积层(包括基本的残差块和若干个卷积层)stages.append(nn.Sequential(BasicResBlock( # 基本的残差块,用于特征学习conv_op=encoder.conv_op,norm_op=encoder.norm_op,norm_op_kwargs=encoder.norm_op_kwargs,nonlin=encoder.nonlin,nonlin_kwargs=encoder.nonlin_kwargs,input_channels=2 * input_features_skip if s < n_stages_encoder - 1 else input_features_skip, # 中间特征拼接输入output_channels=input_features_skip, # 输出通道数kernel_size=encoder.kernel_sizes[-(s + 1)], # 卷积核大小padding=encoder.conv_pad_sizes[-(s + 1)], # 填充大小stride=1, # 步幅为1use_1x1conv=True # 是否使用1x1卷积),*[ # 使用 BasicBlockD 创建多个卷积层BasicBlockD(conv_op=encoder.conv_op,input_channels=input_features_skip,output_channels=input_features_skip,kernel_size=encoder.kernel_sizes[-(s + 1)],stride=1,conv_bias=encoder.conv_bias,norm_op=encoder.norm_op,norm_op_kwargs=encoder.norm_op_kwargs,nonlin=encoder.nonlin,nonlin_kwargs=encoder.nonlin_kwargs,) for _ in range(n_conv_per_stage[s-1] - 1) # 每个阶段有多个卷积层]))# 最后一层卷积层,用于生成分割结果seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))# 将所有阶段、上采样层和分割层分别放入 nn.ModuleListself.stages = nn.ModuleList(stages)self.upsample_layers = nn.ModuleList(upsample_layers)self.seg_layers = nn.ModuleList(seg_layers)def forward(self, skips):lres_input = skips[-1] # 最后一层的skip连接作为初始输入seg_outputs = [] # 用于存储分割结果# 对每个解码阶段进行处理for s in range(len(self.stages)):x = self.upsample_layers[s](lres_input) # 上采样当前输入if s < (len(self.stages) - 1): # 如果不是最后一个阶段x = torch.cat((x, skips[-(s+2)]), 1) # 拼接来自编码器的跳跃连接x = self.stages[s](x) # 通过当前阶段的卷积层进行特征处理if self.deep_supervision: # 如果启用了深度监督seg_outputs.append(self.seg_layers[s](x)) # 每个阶段都生成一个分割输出elif s == (len(self.stages) - 1): # 如果是最后一个阶段seg_outputs.append(self.seg_layers[-1](x)) # 只对最后一个阶段生成分割输出lres_input = x # 更新输入,供下一阶段使用seg_outputs = seg_outputs[::-1] # 将分割结果反转,保证输出顺序是从最浅到最深# 如果没有深度监督,返回第一个阶段的输出;否则,返回所有阶段的输出if not self.deep_supervision:r = seg_outputs[0]else:r = seg_outputsreturn r # 返回最终的分割结果
这段代码实现了一个 U-Net 风格的解码器部分,结合了残差网络(ResNet)和深度监督(Deep Supervision)。具体的功能包括:
上采样与特征融合:每个解码阶段先通过上采样将特征图尺寸恢复,然后与来自编码器的跳跃连接(skip connection)进行拼接,提供更多的细节信息。
残差块与卷积层:在每个解码阶段使用基本的残差块(
BasicResBlock
)进行特征提取。每个阶段还可以有多个卷积层,用于进一步精炼特征。深度监督:如果启用了深度监督,每个解码阶段都会生成一个分割结果,这些结果被返回并可以用于训练;如果没有深度监督,则只返回最后阶段的分割结果。
分割层:每个阶段的最后都会通过一个卷积层生成分割结果,最终生成每个像素的类别预测。
compute_conv_feature_map_size
def compute_conv_feature_map_size(self, input_size):# 初始化一个空列表,用来保存每个阶段(解码器部分)计算出来的特征图尺寸skip_sizes = []# 遍历编码器阶段,计算每个阶段的特征图尺寸for s in range(len(self.encoder.strides) - 1):# 对于编码器的每个阶段,根据步幅(strides)计算输入尺寸的缩小# `zip(input_size, self.encoder.strides[s])` 会将每个维度的输入尺寸和步幅配对,逐元素进行整除,得到每个维度上的特征图尺寸skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])# 更新输入尺寸为当前阶段计算出的特征图尺寸,供下一阶段使用input_size = skip_sizes[-1]# 确保计算得到的跳跃连接的尺寸列表与解码器阶段数一致assert len(skip_sizes) == len(self.stages)# 初始化输出变量,用于累加每个阶段计算出的特征图的总大小output = np.int64(0)# 遍历解码器阶段,计算每个阶段的特征图尺寸for s in range(len(self.stages)):# 使用当前阶段的 `compute_conv_feature_map_size` 方法来计算每个卷积层的特征图尺寸,并累加output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)])# 计算当前阶段跳跃连接部分的特征图尺寸,并累加# `self.encoder.output_channels[-(s+2)]` 代表当前阶段跳跃连接的通道数# `skip_sizes[-(s+1)]` 代表跳跃连接的特征图尺寸(来自编码器)output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64)# 如果启用了深度监督,或者是最后一个阶段,则计算每个分割输出的特征图尺寸并累加# `self.num_classes` 是目标类别数,`skip_sizes[-(s+1)]` 是特征图的尺寸if self.deep_supervision or (s == (len(self.stages) - 1)):output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64)# 返回所有卷积层输出特征图的总尺寸return output
这段代码的目的是计算整个网络中所有卷积层(包括解码器阶段、跳跃连接和分割输出)的特征图尺寸。这个计算对内存和计算需求非常关键,尤其是在需要优化内存占用或者进行硬件部署时。具体来说,代码执行了以下几个步骤:
计算跳跃连接尺寸:对于每个编码器阶段,使用步幅 (
self.encoder.strides[s]
) 来计算输入尺寸在该阶段后的缩小特征图尺寸,并将结果保存在skip_sizes
列表中。每个阶段的输入尺寸都会被更新,以供下一个阶段使用。验证跳跃连接尺寸:确保
skip_sizes
列表的长度与解码器的阶段数量一致,这样可以保证每个解码器阶段都有对应的跳跃连接尺寸。计算卷积层的特征图尺寸:对于每个解码器阶段,首先计算该阶段卷积层的输出特征图尺寸。然后,计算跳跃连接部分的特征图尺寸,并加到总尺寸中。
计算深度监督输出和分割层输出:如果启用了深度监督或者是最后一个阶段,计算分割输出的特征图尺寸(包括类别数和特征图尺寸)。每个分割层的输出大小会被加到总尺寸中。
返回总特征图尺寸:最终,所有阶段的特征图尺寸都会被累加,并返回一个总尺寸,表示整个网络的卷积层所需要的内存。
UMambaEnc
class UMambaEnc(nn.Module):def __init__(self,input_size: Tuple[int, ...], # 输入张量的尺寸 (例如 [H, W] 或 [D, H, W])input_channels: int, # 输入通道数,例如图像的通道数 (例如 1 为灰度图,3 为 RGB 图像)n_stages: int, # 网络的阶段数,也就是编码器和解码器的阶段数量features_per_stage: Union[int, List[int], Tuple[int, ...]], # 每个阶段的特征图数量conv_op: Type[_ConvNd], # 卷积操作的类型,例如标准卷积或可分离卷积kernel_sizes: Union[int, List[int], Tuple[int, ...]], # 卷积核的大小,可能是整数或列表/元组strides: Union[int, List[int], Tuple[int, ...]], # 步幅,可能是整数或列表/元组n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], # 每个阶段的卷积层数量num_classes: int, # 最终的分类数(即输出类别数)n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], # 解码器阶段的卷积层数量conv_bias: bool = False, # 是否使用卷积层的偏置项norm_op: Union[None, Type[nn.Module]] = None, # 批量归一化或其他归一化操作类型norm_op_kwargs: dict = None, # 归一化操作的超参数dropout_op: Union[None, Type[_DropoutNd]] = None, # Dropout 操作类型dropout_op_kwargs: dict = None, # Dropout 操作的超参数nonlin: Union[None, Type[torch.nn.Module]] = None, # 激活函数类型(例如 ReLU、LeakyReLU 等)nonlin_kwargs: dict = None, # 激活函数的超参数deep_supervision: bool = False, # 是否启用深度监督stem_channels: int = None # 初始阶段的通道数):super().__init__()# 如果 `n_conv_per_stage` 是一个整数,转换为每个阶段的卷积层数量列表n_blocks_per_stage = n_conv_per_stageif isinstance(n_blocks_per_stage, int):n_blocks_per_stage = [n_blocks_per_stage] * n_stages# 如果 `n_conv_per_stage_decoder` 是一个整数,转换为每个解码器阶段的卷积层数量列表if isinstance(n_conv_per_stage_decoder, int):n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)# 为了减小网络的最后阶段计算量,后半段(编码器部分的后面)每个阶段的卷积层数量设置为 1for s in range(math.ceil(n_stages / 2), n_stages):n_blocks_per_stage[s] = 1 # 对解码器部分的最后几层设置卷积层数量为 1for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1):n_conv_per_stage_decoder[s] = 1# 校验 `n_blocks_per_stage` 和 `n_conv_per_stage_decoder` 的长度是否与 `n_stages` 匹配assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \f"resolution stages. here: {n_stages}. " \f"n_blocks_per_stage: {n_blocks_per_stage}"assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \f"as we have resolution stages. here: {n_stages} " \f"stages, so it should have {n_stages - 1} entries. " \f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"# 构建编码器部分,`ResidualMambaEncoder` 负责提取特征并输出跳跃连接self.encoder = ResidualMambaEncoder(input_size, # 输入尺寸input_channels, # 输入通道数n_stages, # 网络阶段数features_per_stage, # 每个阶段的特征数conv_op, # 卷积操作类型kernel_sizes, # 卷积核大小strides, # 步幅n_blocks_per_stage, # 每个阶段的卷积层数量conv_bias, # 是否使用卷积偏置norm_op, # 是否使用归一化操作norm_op_kwargs, # 归一化操作的超参数nonlin, # 激活函数nonlin_kwargs, # 激活函数的超参数return_skips=True, # 是否返回跳跃连接stem_channels=stem_channels # 初始阶段通道数)# 构建解码器部分,`UNetResDecoder` 负责将编码器的特征图重建为最终的输出self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision)def forward(self, x):# 编码器部分,提取特征并返回跳跃连接skips = self.encoder(x)# 解码器部分,基于跳跃连接和编码器特征图进行重建,返回最终的输出return self.decoder(skips)
初始化方法(
__init__
):
- 构建了一个带有编码器和解码器结构的 U-Net 风格的神经网络。
- 初始化过程中接收许多参数,这些参数用来配置卷积层的数量、尺寸、卷积核大小、步幅、激活函数、归一化操作等。
编码器部分(
ResidualMambaEncoder
):
- 使用残差连接和多阶段特征提取方法(即
ResidualMambaEncoder
),提取输入数据的多层特征图。- 编码器会生成跳跃连接,用于解码器阶段重建更细粒度的特征。
解码器部分(
UNetResDecoder
):
- 解码器通过使用跳跃连接中的特征图和编码器的输出,逐步重建出最终的输出,常用于语义分割等任务。
- 如果启用了深度监督,解码器会在不同的阶段提供额外的监督信号。
卷积层配置:
n_conv_per_stage
和n_conv_per_stage_decoder
控制每个阶段的卷积层数量,在构建网络时,会根据网络的阶段数配置相应数量的卷积层。- 通过动态计算
n_blocks_per_stage
和n_conv_per_stage_decoder
来确保不同阶段具有不同数量的卷积层。跳跃连接和深度监督:
return_skips=True
确保编码器返回跳跃连接,以便解码器使用这些特征图进行逐层重建。deep_supervision
启用时,解码器会在中间阶段输出结果,进行多尺度监督
compute_conv_feature_map_size
def compute_conv_feature_map_size(self, input_size):# 确保输入的尺寸格式正确,不包含批量和颜色/特征通道信息assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \"batch channel. Do not give input_size=(b, c, x, y(, z)). " \"Give input_size=(x, y(, z))!"# 调用编码器部分的 compute_conv_feature_map_size 计算卷积特征图的尺寸# 编码器的计算方式通常会影响特征图尺寸(如步幅、池化等操作)encoder_feature_map_size = self.encoder.compute_conv_feature_map_size(input_size)# 调用解码器部分的 compute_conv_feature_map_size 计算卷积特征图的尺寸# 解码器通常会通过上采样或反卷积操作逐步恢复特征图的尺寸decoder_feature_map_size = self.decoder.compute_conv_feature_map_size(input_size)# 返回编码器和解码器计算的特征图尺寸总和return encoder_feature_map_size + decoder_feature_map_size
这段代码的作用是计算给定输入尺寸下,通过编码器和解码器网络部分的卷积特征图尺寸。它首先检查输入尺寸的维度是否与编码器的卷积操作维度匹配,然后调用编码器和解码器的
compute_conv_feature_map_size
方法来计算最终的特征图尺寸。
get_umamba_enc_2d_from_plans
该函数用于构建一个名为 UMambaEnc
的深度神经网络模型。
def get_umamba_enc_2d_from_plans(plans_manager: PlansManager, # 传入的计划管理器,负责管理配置信息和训练计划dataset_json: dict, # 传入的数据集配置,以字典形式传递configuration_manager: ConfigurationManager, # 配置管理器,包含网络的超参数和其他设置num_input_channels: int, # 输入图像的通道数(如RGB图像为3通道)deep_supervision: bool = True # 是否使用深度监督,默认为True):num_stages = len(configuration_manager.conv_kernel_sizes) # 获取卷积核大小的数量,决定网络的阶段数(层数)dim = len(configuration_manager.conv_kernel_sizes[0]) # 获取第一个卷积核的维度,通常是2D或3Dconv_op = convert_dim_to_conv_op(dim) # 根据维度转换为对应的卷积操作类型(如Conv2d或Conv3d)label_manager = plans_manager.get_label_manager(dataset_json) # 获取标签管理器,管理数据集中的标签信息segmentation_network_class_name = 'UMambaEnc' # 定义网络类名,UMambaEncnetwork_class = UMambaEnc # 设置网络类为UMambaEnckwargs = { # 定义与UMambaEnc相关的超参数设置'UMambaEnc': {'input_size': configuration_manager.patch_size, # 输入图像的尺寸,从配置管理器获取'conv_bias': True, # 卷积层使用偏置项'norm_op': get_matching_instancenorm(conv_op), # 获取与卷积操作类型匹配的归一化操作(通常是InstanceNorm)'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, # 归一化操作的参数:epsilon和是否使用仿射变换'dropout_op': None, 'dropout_op_kwargs': None, # 不使用Dropout'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, # 使用LeakyReLU激活函数,inplace=True表示直接在原地修改}}# 设置每个阶段(编码器和解码器)的卷积块数量conv_or_blocks_per_stage = {'n_conv_per_stage': configuration_manager.n_conv_per_stage_encoder, # 编码器每个阶段的卷积块数'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder # 解码器每个阶段的卷积块数}# 创建UMambaEnc网络模型model = network_class(input_channels=num_input_channels, # 输入通道数n_stages=num_stages, # 网络阶段数(层数)features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i,configuration_manager.unet_max_num_features) for i in range(num_stages)], # 每个阶段的特征图数conv_op=conv_op, # 卷积操作类型kernel_sizes=configuration_manager.conv_kernel_sizes, # 卷积核的大小strides=configuration_manager.pool_op_kernel_sizes, # 池化操作的步幅num_classes=label_manager.num_segmentation_heads, # 分割任务的类别数量(从标签管理器获取)deep_supervision=deep_supervision, # 是否启用深度监督**conv_or_blocks_per_stage, # 解码器和编码器的卷积块数量**kwargs[segmentation_network_class_name] # 获取UMambaEnc相关的超参数并传递)model.apply(InitWeights_He(1e-2)) # 使用He初始化方法初始化网络的权重,1e-2是初始化的标准差return model # 返回创建的网络模型
这个函数的作用是根据给定的配置信息(如网络结构、卷积核尺寸、输入图像大小等)以及数据集信息,创建一个深度学习模型(
UMambaEnc
)。它适用于图像分割任务,并支持深度监督(即在多个层次上进行监督学习)。通过该函数,用户可以灵活地根据不同的配置和数据集构建出一个特定的网络架构,并初始化网络的权重。