论文地址:https://arxiv.org/abs/2310.08528
源码地址:https://github.com/hustvl/4DGaussians
项目地址:https://guanjunwu.github.io/4dgs/
论文总结
4DGS是DeformableGS的升级版,再变形网络的技术上增加了多分辨率的编码器,提高了整体的表现;根据代码中的实际配置来看,这个方案的速度可能是个短板,不过也是一篇很不错的方案了。同时解决了多分辨率和时间两个大问题;
论文背景
动态场景渲染是计算机视觉和计算机图形学中的一项核心任务,尤其在虚拟现实、电影特效、游戏开发等领域具有广泛的应用前景。当前的方法主要有以下两类:
- 基于网格的传统方法:这些方法通常需要大量的手动操作和复杂的预处理来捕捉动态场景的几何变化。这些方法往往计算量大且难以实现实时渲染。
- 基于神经辐射场(NeRF)的方法:NeRF及其变种通过对场景进行体渲染,能够生成高质量的新视角图像。然而,NeRF在处理动态场景时面临着两个主要挑战:
- 高计算成本:动态场景需要大量的训练数据和复杂的计算,导致了高昂的计算成本。
- 低实时性:由于其复杂的计算过程,NeRF难以实现实时渲染。
论文贡献
为了克服上述挑战,论文提出了一种新的方法,即4D Gaussian Splatting(4D-GS),其主要贡献如下:
- 4D高斯点扩散框架:提出了一种新的框架,用4D高斯点(包含空间和时间信息)来表示动态场景。这种表示方法允许场景在时间维度上连续变化,从而可以实现高效的动态场景渲染。
- 多分辨率编码方法:设计了一种多分辨率编码方法,能够在不同分辨率下对动态场景进行有效编码,提高了场景的细节表现能力和计算效率。
- 实时渲染:4D-GS方法可以在复杂的动态场景下实现实时渲染,显著提高了渲染速度。具体来说,能够在高分辨率下实现每秒30到82帧的渲染速度。
方法详解
1. 4D高斯表示
4D高斯表示是一种显式的场景表示方法,它结合了3D空间信息和时间维度信息。具体来说,场景在每一个时间点用一组3D高斯分布来表示。这些3D高斯分布描述了场景中物体的空间分布和几何形状,而时间维度通过4D神经体素来捕捉变化。这样的表示方法可以高效地捕捉动态场景的几何变化和运动信息。
2. 高斯形变场网络
高斯形变场网络的核心在于通过一个小型的网络(多头高斯形变解码器)来预测高斯点在不同时间点的形变。该网络由两个主要部分组成:
- 空间-时间结构编码器:用于编码场景的空间和时间信息。通过连接邻近的3D高斯点,编码器生成丰富的3D高斯特征,这些特征能够表示场景在不同时间点的变化。
- 高斯形变解码器:解码器利用编码器生成的特征,预测3D高斯点在新时间点上的形变。这样可以高效地更新场景信息,使得高斯点在时间维度上平滑变化。
3. 差分扩散渲染
在渲染阶段,论文提出了一种新的差分扩散渲染方法。与传统的体渲染不同,该方法通过高斯点扩散将3D高斯分布直接投影到2D平面。这种方法具有以下优点:
- 高效性:由于只需要处理高斯点的投影,不需要复杂的光线追踪或体渲染,显著减少了计算开销。
- 实时性:投影过程简单且快速,适合实时渲染应用。
实验和结果
实验设置
论文进行了大量实验,使用了合成数据集和真实数据集来验证方法的有效性。实验主要包括以下几个方面:
- 合成数据集:模拟了多种动态场景,包括简单的几何形状变化和复杂的运动场景,用于验证方法在理想条件下的性能。
- 真实数据集:使用了真实的动态场景数据,如视频序列和动态捕捉数据,来测试方法在实际应用中的表现。
实验结果
实验结果表明,4D-GS在各个方面均取得了优异的性能:
- 渲染速度:在合成数据集上,4D-GS能够在800×800分辨率下实现每秒82帧的渲染速度;在真实数据集上,在1352×1014分辨率下也能实现每秒30帧的渲染速度。这表明该方法能够满足实时渲染的需求。
- 渲染质量:与其他先进方法相比,4D-GS在保持高质量图像的同时,能够更加高效地处理复杂的动态场景。其渲染效果在细节保真度和运动流畅性上均表现出色。
- 灵活性:4D-GS不仅适用于传统的动态场景渲染,还在4D场景编辑和跟踪任务中展现了很好的潜力。特别是在需要实时更新场景或处理复杂运动变化的应用中,该方法具有显著优势。
结论
综上所述,4D-GS提供了一种新的高效的动态场景渲染解决方案。通过引入4D高斯表示和差分扩散渲染方法,该方法成功克服了传统方法在处理动态场景时面临的高计算成本和低实时性问题。未来的研究可以进一步优化该方法,并探索其在更多实际应用中的潜力,如虚拟现实、实时视频处理等领域。
源码精读
新增网络结构
下面先介绍4DGS比3DGS多的1D是什么,其实就是deformableGS中的时间维度,本文中还是采用了学习方案来处理时间和空间的关系;下面的三个网络结构是按照代码中的调用顺序进行介绍。
deform_network
网络结构
代码定义了一个名为 deform_network
的神经网络模块,用于处理复杂的几何变形。这个模块就是gaussian_model
中的self._deformation
,算是整个论文的创新点,其中包含了多头高斯变形解码器和时空结构编码器;
- 初始化函数
class deform_network(nn.Module):def __init__(self, args):super(deform_network, self).__init__()# 下面的参数在argument的初始化函数中设置了net_width = args.net_widthtimebase_pe = args.timebase_pedefor_depth = args.defor_depthposbase_pe = args.posebase_pescale_rotation_pe = args.scale_rotation_peopacity_pe = args.opacity_petimenet_width = args.timenet_widthtimenet_output = args.timenet_outputgrid_pe = args.grid_petimes_ch = 2 * timebase_pe + 1 # 时间嵌入通道数self.timenet = nn.Sequential(nn.Linear(times_ch, timenet_width), nn.ReLU(),nn.Linear(timenet_width, timenet_output)) # 处理时间嵌入的网络self.deformation_net = Deformation(W=net_width, D=defor_depth,input_ch=3 + (3 * posbase_pe) * 2, grid_pe=grid_pe, input_ch_time=timenet_output, args=args)# 模型中添加多个参数,但是不作为参数更新self.register_buffer('time_poc', torch.FloatTensor([(2 ** i) for i in range(timebase_pe)]))self.register_buffer('pos_poc', torch.FloatTensor([(2 ** i) for i in range(posbase_pe)]))self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2 ** i) for i in range(scale_rotation_pe)]))self.register_buffer('opacity_poc', torch.FloatTensor([(2 ** i) for i in range(opacity_pe)]))# 对模型参数进行初始化,调用initialize_weights函数self.apply(initialize_weights)
- 动态前向传播函数
def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, shs=None, times_sel=None):point_emb = poc_fre(point, self.pos_poc)scales_emb = poc_fre(scales, self.rotation_scaling_poc)rotations_emb = poc_fre(rotations, self.rotation_scaling_poc)means3D, scales, rotations, opacity, shs = self.deformation_net(point_emb, scales_emb, rotations_emb, opacity, shs, None, times_sel)return means3D, scales, rotations, opacity, shs
- 权重初始函数和频率嵌入函数
# 权重初始化,初始化方式是Glorot initialization,避免梯度消失或者爆炸
def initialize_weights(m):if isinstance(m, nn.Linear):# init.constant_(m.weight, 0)init.xavier_uniform_(m.weight, gain=1)if m.bias is not None:init.xavier_uniform_(m.weight, gain=1)# init.constant_(m.bias, 0)# 频率嵌入函数,输入数据转换为频域表示
def poc_fre(input_data, poc_buf):input_data_emb = (input_data.unsqueeze(-1) * poc_buf).flatten(-2)input_data_sin = input_data_emb.sin()input_data_cos = input_data_emb.cos()input_data_emb = torch.cat([input_data, input_data_sin, input_data_cos], -1)return input_data_emb
deform_network
网络在render
函数时进行了调用,这里输入的means3D
等变量就是三维高斯的属性;
if "coarse" in stage:means3D_final, scales_final, rotations_final, opacity_final, shs_final = means3D, scales, rotations, opacity, shselif "fine" in stage:# time0 = get_time()# means3D_deform, scales_deform, rotations_deform, opacity_deform = pc._deformation(means3D[deformation_point], scales[deformation_point], # rotations[deformation_point], opacity[deformation_point],# time[deformation_point])means3D_final, scales_final, rotations_final, opacity_final, shs_final = pc._deformation(means3D, scales,rotations, opacity, shs,time)
Deformation
网络结构
Deformation
是论文中提到的多头高斯变形解码器 D \mathcal{D} D,其实时空联合编码器是在整个类中被定义及调用的,之前提到的deform_network
是包含整个解码器的外壳,并提供了时间网络的输入;
class Deformation(nn.Module):def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, grid_pe=0, skips=[], args=None):super(Deformation, self).__init__()self.D = Dself.W = Wself.input_ch = input_chself.input_ch_time = input_ch_timeself.skips = skipsself.grid_pe = grid_peself.no_grid = args.no_grid# 时空联合编码器self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires)# breakpoint()self.args = args# self.args.empty_voxel=True, 代码中没有调用if self.args.empty_voxel:self.empty_voxel = DenseGrid(channels=1, world_size=[64, 64, 64])if self.args.static_mlp: # 没有调用self.static_mlp = nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, 1))self.ratio = 0self.create_net()def create_net(self):mlp_out_dim = 0if self.grid_pe != 0:grid_out_dim = self.grid.feat_dim + self.grid.feat_dim * 2else:grid_out_dim = self.grid.feat_dimif self.no_grid:self.feature_out = [nn.Linear(4, self.W)]else:self.feature_out = [nn.Linear(mlp_out_dim + grid_out_dim, self.W)]# feature_out 是论文中提到的小型MLP,用于处理for i in range(self.D - 1):self.feature_out.append(nn.ReLU())self.feature_out.append(nn.Linear(self.W, self.W))self.feature_out = nn.Sequential(*self.feature_out)# Multi-head Gaussian Deformation Decoderself.pos_deform = nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, 3))self.scales_deform = nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, 3))self.rotations_deform = nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, 4))self.opacity_deform = nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, 1))self.shs_deform = nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, 16 * 3))def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb):if self.no_grid:h = torch.cat([rays_pts_emb[:, :3], time_emb[:, :1]], -1)else:# 此处调用了HexPlaneField的前馈函数grid_feature = self.grid(rays_pts_emb[:, :3], time_emb[:, :1])# breakpoint()if self.grid_pe > 1:grid_feature = poc_fre(grid_feature, self.grid_pe)hidden = torch.cat([grid_feature], -1)hidden = self.feature_out(hidden)return hidden def forward_dynamic(self, rays_pts_emb, scales_emb, rotations_emb, opacity_emb, shs_emb, time_feature, time_emb):hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb)# 多头注意力解码器,用于输出多个高斯属性if self.args.static_mlp:mask = self.static_mlp(hidden)elif self.args.empty_voxel:mask = self.empty_voxel(rays_pts_emb[:, :3])else:mask = torch.ones_like(opacity_emb[:, 0]).unsqueeze(-1)if self.args.no_dx:pts = rays_pts_emb[:, :3]else:dx = self.pos_deform(hidden)pts = torch.zeros_like(rays_pts_emb[:, :3])pts = rays_pts_emb[:, :3] * mask + dxif self.args.no_ds:scales = scales_emb[:, :3]else:ds = self.scales_deform(hidden)scales = torch.zeros_like(scales_emb[:, :3])scales = scales_emb[:, :3] * mask + dsif self.args.no_dr:rotations = rotations_emb[:, :4]else:dr = self.rotations_deform(hidden)rotations = torch.zeros_like(rotations_emb[:, :4])if self.args.apply_rotation:rotations = batch_quaternion_multiply(rotations_emb, dr)else:rotations = rotations_emb[:, :4] + drif self.args.no_do:opacity = opacity_emb[:, :1]else:do = self.opacity_deform(hidden)opacity = torch.zeros_like(opacity_emb[:, :1])opacity = opacity_emb[:, :1] * mask + doif self.args.no_dshs:shs = shs_embelse:dshs = self.shs_deform(hidden).reshape([shs_emb.shape[0], 16, 3])shs = torch.zeros_like(shs_emb)# breakpoint()shs = shs_emb * mask.unsqueeze(-1) + dshsreturn pts, scales, rotations, opacity, shs
HexPlaneField
结构
这里的多分辨率结构可以看参考文献的第2篇论文,论文中有更形象的示意图以及更详细的论证;这里的多分辨率模型就是将4维参数两两组合为2维平面然后进行插值;下图可以很形象的表示4D如何两两组合成2D并进行插值映射的;
代码定义了一个 HexPlaneField
类,继承自 nn.Module
,用于处理多分辨率的空间特征场景。
- 初始化方法
__init__
# 先贴出来config的定义
self.kplanes_config = {'grid_dimensions': 2,'input_coordinate_dim': 4,'output_coordinate_dim': 32,
# [64,64,64]: resolution of spatial grid. 25: resolution of temporal grid, better to be half length of dynamic frames 'resolution': [64, 64, 64, 25] }
class HexPlaneField(nn.Module):# 输入场景边界,分辨率配置,分辨率倍数def __init__(self, bounds, planeconfig, multires) -> None:super().__init__()aabb = torch.tensor([[bounds, bounds, bounds],[-bounds, -bounds, -bounds]])self.aabb = nn.Parameter(aabb, requires_grad=False)self.grid_config = [planeconfig]self.multiscale_res_multipliers = multiresself.concat_features = True# 1. Init planesself.grids = nn.ModuleList()self.feat_dim = 0for res in self.multiscale_res_multipliers:# initialize coordinate gridconfig = self.grid_config[0].copy()# Resolution fix: multi-res only on spatial planesconfig["resolution"] = [r * res for r in config["resolution"][:3]] + config["resolution"][3:]gp = init_grid_param(grid_nd=config["grid_dimensions"],in_dim=config["input_coordinate_dim"],out_dim=config["output_coordinate_dim"],reso=config["resolution"],)# shape[1] is out-dim - Concatenate over feature len for each scaleif self.concat_features:self.feat_dim += gp[-1].shape[1]else:self.feat_dim = gp[-1].shape[1]self.grids.append(gp)# print(f"Initialized model grids: {self.grids}")print("feature_dim:", self.feat_dim)
- 获取密度函数&前馈函数
def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):"""Computes and returns the densities."""# breakpoint()# 将点进行 AABB 归一化,连接时间戳,重塑为二维,pts是所有三维高斯的三维位置,timestamps是时间序列pts = normalize_aabb(pts, self.aabb)pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4]pts = pts.reshape(-1, pts.shape[-1])# 插值多尺度特征。features = interpolate_ms_features(pts, ms_grids=self.grids, # noqagrid_dimensions=self.grid_config[0]["grid_dimensions"],concat_features=self.concat_features, num_levels=None)if len(features) < 1:features = torch.zeros((0, 1)).to(features.device)return featuresdef forward(self,pts: torch.Tensor,timestamps: Optional[torch.Tensor] = None):features = self.get_density(pts, timestamps)return features
- 特征插值函数
def interpolate_ms_features(pts: torch.Tensor,ms_grids: Collection[Iterable[nn.Module]],grid_dimensions: int,concat_features: bool,num_levels: Optional[int],) -> torch.Tensor:coo_combs = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))if num_levels is None:num_levels = len(ms_grids)multi_scale_interp = [] if concat_features else 0.grid: nn.ParameterListfor scale_id, grid in enumerate(ms_grids[:num_levels]):interp_space = 1.for ci, coo_comb in enumerate(coo_combs):# interpolate in planefeature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso# 双线性插值方法对每个平面进行插值interp_out_plane = (grid_sample_wrapper(grid[ci], pts[..., coo_comb]).view(-1, feature_dim))# compute product over planesinterp_space = interp_space * interp_out_plane# combine over scalesif concat_features:multi_scale_interp.append(interp_space)else:multi_scale_interp = multi_scale_interp + interp_spaceif concat_features:multi_scale_interp = torch.cat(multi_scale_interp, dim=-1)return multi_scale_interp
- grid初始化函数
def init_grid_param(grid_nd: int,in_dim: int,out_dim: int,reso: Sequence[int],a: float = 0.1,b: float = 0.5):assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"has_time_planes = in_dim == 4assert grid_nd <= in_dim# 坐标系索引组合coo_combs = list(itertools.combinations(range(in_dim), grid_nd))grid_coefs = nn.ParameterList()for ci, coo_comb in enumerate(coo_combs):new_grid_coef = nn.Parameter(torch.empty([1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]))if has_time_planes and 3 in coo_comb: # Initialize time planes to 1nn.init.ones_(new_grid_coef)else:nn.init.uniform_(new_grid_coef, a=a, b=b)grid_coefs.append(new_grid_coef)return grid_coefs
到这里我们来分析一下self.grids
这个变量的形状,它是一个nn.ModuleList()
长度取决于multiscale_res_multipliers
,这个值再配置里面是 [ 1 , 2 , 4 , 8 ] [1,2,4,8] [1,2,4,8],因此长度为 4 4 4;每个元素nn.ParameterList
变量,也就是gp
的长度根据配置项可以知道为 6 6 6,因为是4D的两两组合;那么gp
中的维度可以根据init_grid_param
函数得到,应该为 [ 1 , out_dim , reso , reso ] [1, \text{out\_dim}, \text{reso}, \text{reso}] [1,out_dim,reso,reso],但是这里的reso应该根据不同分辨率和是否引入时间维度决定。这里给出一个当 res = 1
时配置的维度参考;从表格可以看出, ( 0 , 1 , 3 ) (0,1,3) (0,1,3)三个维度是没有时间维度的, ( 2 , 4 , 5 ) (2,4,5) (2,4,5)三个维度是包含了时间维度,这个表格再后面的loss计算上会用到;
训练过程
训练过程和原版3DGS没有太多的变化,针对4D加入了一个tv_loss
,densify和prune部分也没什么差异,下面看一下整个tv_loss
是如何计算的;
tv_loss
是一个正则项,包括三部分,平面正则,时间平滑和L1正则;这三个正则项都和变形网络的grids
相关,读者可以返回前面看一下grids
的最终输出是怎样的,可以加深对这里三个loss作用的理解;
def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):return plane_tv_weight * self._plane_regulation() \+ time_smoothness_weight * self._time_regulation() \+ l1_time_planes_weight * self._l1_regulation()
- 平面正则:其实是对XYZ三个维度进行差分计算,实现平滑;
# 某个维度的一阶和二阶差分
def compute_plane_smoothness(t):batch_size, c, h, w = t.shape# Convolve with a second derivative filter, in the time dimension which is dimension 2first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w]second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w]# Take the L2 norm of the resultreturn torch.square(second_difference).mean()def _plane_regulation(self):multi_res_grids = self._deformation.deformation_net.grid.gridstotal = 0# model.grids is 6 x [1, rank * F_dim, reso, reso]for grids in multi_res_grids:if len(grids) == 3:time_grids = []else:time_grids = [0,1,3]for grid_id in time_grids:total += compute_plane_smoothness(grids[grid_id])return total
- 时间平滑:保证相邻时间的图像平滑过渡
# 对时间进行平滑
def _time_regulation(self):multi_res_grids = self._deformation.deformation_net.grid.gridstotal = 0# model.grids is 6 x [1, rank * F_dim, reso, reso]for grids in multi_res_grids:if len(grids) == 3:time_grids = []else:time_grids =[2, 4, 5]for grid_id in time_grids:total += compute_plane_smoothness(grids[grid_id])return total
- L1正则:保证图像尽量接近1;
# 时空联合维度上面进行全图的正则
def _l1_regulation(self):# model.grids is 6 x [1, rank * F_dim, reso, reso]multi_res_grids = self._deformation.deformation_net.grid.gridstotal = 0.0for grids in multi_res_grids:if len(grids) == 3:continueelse:# These are the spatiotemporal gridsspatiotemporal_grids = [2, 4, 5]for grid_id in spatiotemporal_grids:total += torch.abs(1 - grids[grid_id]).mean()return total
如果你喜欢我的文章欢迎点赞、关注,同时非常欢迎一起探讨交流相关技术,谢谢支持;