您的位置:首页 > 房产 > 建筑 > 衢州网站设计排名_龙岩做网站开发大概价格_免费html网页模板_公司全网推广

衢州网站设计排名_龙岩做网站开发大概价格_免费html网页模板_公司全网推广

2025/1/7 4:02:31 来源:https://blog.csdn.net/wei582636312/article/details/144332833  浏览:    关键词:衢州网站设计排名_龙岩做网站开发大概价格_免费html网页模板_公司全网推广
衢州网站设计排名_龙岩做网站开发大概价格_免费html网页模板_公司全网推广

在这里插入图片描述

文章目录

  • 1、Dynamic Upsampling
  • 2、代码实现

paper:Learning to Upsample by Learning to Sample

Code:https://github.com/tiny-smart/dysample


1、Dynamic Upsampling

论文指出在现有的研究中的一些弊端:现有的上采样器虽然性能出色,但引入了大量的计算负担,这些主要是由于动态卷积的计算成本和生成动态核的额外子网络。所以,这篇论文提出一种动态上采样(Dynamic Upsampling),是一种轻量级且有效的动态上采样,Dynamic Upsampling是为了解决现有基于核的动态上采样器的局限性。

在这篇论文中,DySample 将上采样过程视为点采样,从而绕过了基于核的范式。其核心思想是将输入特征插值到连续的空间,并生成内容感知的采样点来重新采样连续图。

对于一个输入X,DySample 的实现过程:

  1. 特征插值: 首先使用双线性插值将输入特征图 X 插值到连续的空间,得到插值后的特征图 X’。
  2. 采样点生成: 采样点生成器负责生成内容感知的采样点集 S,用于对插值后的特征图 X’ 进行重采样。
  3. 偏移量生成: 使用线性层或线性层 + Pixel Shuffle 的方式生成偏移量 O。
  4. 采样点集生成: 将偏移量 O 与原始采样网格 G 相加,得到最终的采样点集 S。
  5. 网格采样: 最后使用 grid_sample 函数,根据生成的采样点集 S 对插值后的特征图 X’ 进行重采样,得到最终的上采样特征图 X’'。

论文为了使 DySample 能适用于更多任务中,对DySample做出了几种改进:

  • DySample: 使用静态范围因子和线性层 + Pixel Shuffle 的方式生成偏移量。
  • DySample+: 使用动态范围因子和线性层 + Pixel Shuffle 的方式生成偏移量。
  • DySample-S: 使用静态范围因子和 Pixel Shuffle + 线性层的方式生成偏移量。
  • DySample-S+: 使用动态范围因子和 Pixel Shuffle + 线性层的方式生成偏移量。

Dynamic Upsampling 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as Fdef normal_init(module, mean=0, std=1, bias=0):if hasattr(module, 'weight') and module.weight is not None:nn.init.normal_(module.weight, mean, std)if hasattr(module, 'bias') and module.bias is not None:nn.init.constant_(module.bias, bias)def constant_init(module, val, bias=0):if hasattr(module, 'weight') and module.weight is not None:nn.init.constant_(module.weight, val)if hasattr(module, 'bias') and module.bias is not None:nn.init.constant_(module.bias, bias)class DySample(nn.Module):def __init__(self, in_channels, scale=2, style='pl', groups=4, dyscope=False):super().__init__()self.scale = scaleself.style = styleself.groups = groupsassert style in ['lp', 'pl']if style == 'pl':assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0assert in_channels >= groups and in_channels % groups == 0if style == 'pl':in_channels = in_channels // scale ** 2out_channels = 2 * groupselse:out_channels = 2 * groups * scale ** 2self.offset = nn.Conv2d(in_channels, out_channels, 1)normal_init(self.offset, std=0.001)if dyscope:self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False)constant_init(self.scope, val=0.)self.register_buffer('init_pos', self._init_pos())def _init_pos(self):h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scalereturn torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)def sample(self, x, offset):B, _, H, W = offset.shapeoffset = offset.view(B, 2, -1, H, W)coords_h = torch.arange(H) + 0.5coords_w = torch.arange(W) + 0.5coords = torch.stack(torch.meshgrid([coords_w, coords_h])).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)coords = 2 * (coords + offset) / normalizer - 1coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)def forward_lp(self, x):if hasattr(self, 'scope'):offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_poselse:offset = self.offset(x) * 0.25 + self.init_posreturn self.sample(x, offset)def forward_pl(self, x):x_ = F.pixel_shuffle(x, self.scale)if hasattr(self, 'scope'):offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_poselse:offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_posreturn self.sample(x, offset)def forward(self, x):if self.style == 'pl':return self.forward_pl(x)return self.forward_lp(x)if __name__ == '__main__':"""input : (B, C, H, W)output : (B, C, H * 2, W * 2)"""x = torch.randn(4, 512, 7, 7).cuda()model = DySample(512).cuda()out = model(x)print(out.shape)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com