文章目录
- 1、Dynamic Upsampling
- 2、代码实现
paper:Learning to Upsample by Learning to Sample
Code:https://github.com/tiny-smart/dysample
1、Dynamic Upsampling
论文指出在现有的研究中的一些弊端:现有的上采样器虽然性能出色,但引入了大量的计算负担,这些主要是由于动态卷积的计算成本和生成动态核的额外子网络。所以,这篇论文提出一种动态上采样(Dynamic Upsampling),是一种轻量级且有效的动态上采样,Dynamic Upsampling是为了解决现有基于核的动态上采样器的局限性。
在这篇论文中,DySample 将上采样过程视为点采样,从而绕过了基于核的范式。其核心思想是将输入特征插值到连续的空间,并生成内容感知的采样点来重新采样连续图。
对于一个输入X,DySample 的实现过程:
- 特征插值: 首先使用双线性插值将输入特征图 X 插值到连续的空间,得到插值后的特征图 X’。
- 采样点生成: 采样点生成器负责生成内容感知的采样点集 S,用于对插值后的特征图 X’ 进行重采样。
- 偏移量生成: 使用线性层或线性层 + Pixel Shuffle 的方式生成偏移量 O。
- 采样点集生成: 将偏移量 O 与原始采样网格 G 相加,得到最终的采样点集 S。
- 网格采样: 最后使用 grid_sample 函数,根据生成的采样点集 S 对插值后的特征图 X’ 进行重采样,得到最终的上采样特征图 X’'。
论文为了使 DySample 能适用于更多任务中,对DySample做出了几种改进:
- DySample: 使用静态范围因子和线性层 + Pixel Shuffle 的方式生成偏移量。
- DySample+: 使用动态范围因子和线性层 + Pixel Shuffle 的方式生成偏移量。
- DySample-S: 使用静态范围因子和 Pixel Shuffle + 线性层的方式生成偏移量。
- DySample-S+: 使用动态范围因子和 Pixel Shuffle + 线性层的方式生成偏移量。
Dynamic Upsampling 结构图:
2、代码实现
import torch
import torch.nn as nn
import torch.nn.functional as Fdef normal_init(module, mean=0, std=1, bias=0):if hasattr(module, 'weight') and module.weight is not None:nn.init.normal_(module.weight, mean, std)if hasattr(module, 'bias') and module.bias is not None:nn.init.constant_(module.bias, bias)def constant_init(module, val, bias=0):if hasattr(module, 'weight') and module.weight is not None:nn.init.constant_(module.weight, val)if hasattr(module, 'bias') and module.bias is not None:nn.init.constant_(module.bias, bias)class DySample(nn.Module):def __init__(self, in_channels, scale=2, style='pl', groups=4, dyscope=False):super().__init__()self.scale = scaleself.style = styleself.groups = groupsassert style in ['lp', 'pl']if style == 'pl':assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0assert in_channels >= groups and in_channels % groups == 0if style == 'pl':in_channels = in_channels // scale ** 2out_channels = 2 * groupselse:out_channels = 2 * groups * scale ** 2self.offset = nn.Conv2d(in_channels, out_channels, 1)normal_init(self.offset, std=0.001)if dyscope:self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False)constant_init(self.scope, val=0.)self.register_buffer('init_pos', self._init_pos())def _init_pos(self):h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scalereturn torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)def sample(self, x, offset):B, _, H, W = offset.shapeoffset = offset.view(B, 2, -1, H, W)coords_h = torch.arange(H) + 0.5coords_w = torch.arange(W) + 0.5coords = torch.stack(torch.meshgrid([coords_w, coords_h])).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)coords = 2 * (coords + offset) / normalizer - 1coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)def forward_lp(self, x):if hasattr(self, 'scope'):offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_poselse:offset = self.offset(x) * 0.25 + self.init_posreturn self.sample(x, offset)def forward_pl(self, x):x_ = F.pixel_shuffle(x, self.scale)if hasattr(self, 'scope'):offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_poselse:offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_posreturn self.sample(x, offset)def forward(self, x):if self.style == 'pl':return self.forward_pl(x)return self.forward_lp(x)if __name__ == '__main__':"""input : (B, C, H, W)output : (B, C, H * 2, W * 2)"""x = torch.randn(4, 512, 7, 7).cuda()model = DySample(512).cuda()out = model(x)print(out.shape)