1. Reflow 方法的核心思想
Reflow 方法的核心思想是通过学习一个速度场 v(x,t)v(x, t)v(x,t),使得从初始分布 z0z_0z0 到目标分布 x1x_1x1 的路径可以通过 ODE 求解器生成。具体来说:
前向传播:通过 ODE 求解器从 z0z_0z0 生成 x1x_1x1。
损失计算:计算模型预测的速度场与目标速度场之间的差异。
2. 前向传播
在 Reflow 方法中,前向传播的核心是计算扰动数据 perturbed_data
和模型输出 score
。以下是具体步骤:
(1) 初始样本 z0z_0z0
如果启用了 Reflow 方法(
sde.reflow_flag=True
),则从输入数据batch
中提取初始样本z0
和目标数据data
:z0 = batch[0] data = batch[1] batch = data.detach().clone()
如果没有启用 Reflow 方法,则从初始分布(如高斯分布)中采样
z0
:z0 = sde.get_z0(batch).to(batch.device)
(2) 时间采样 ttt
根据
sde.reflow_t_schedule
采样时间t
:如果
sde.reflow_t_schedule == 't0'
,则固定 t=0t = 0t=0。如果
sde.reflow_t_schedule == 't1'
,则固定 t=1t = 1t=1。如果
sde.reflow_t_schedule == 'uniform'
,则从均匀分布中采样 ttt。如果
sde.reflow_t_schedule
是整数,则从离散时间点中采样 ttt。
if sde.reflow_t_schedule == 't0':t = torch.zeros(batch.shape[0], device=batch.device) * (sde.T - eps) + eps elif sde.reflow_t_schedule == 't1':t = torch.ones(batch.shape[0], device=batch.device) * (sde.T - eps) + eps elif sde.reflow_t_schedule == 'uniform':t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps elif type(sde.reflow_t_schedule) == int:t = torch.randint(0, sde.reflow_t_schedule, (batch.shape[0],), device=batch.device) * (sde.T - eps) / sde.reflow_t_schedule + eps
(3) 扰动数据 xtx_txt
计算扰动数据
perturbed_data
:t_expand = t.view(-1, 1, 1, 1).repeat(1, batch.shape[1], batch.shape[2], batch.shape[3]) perturbed_data = t_expand * batch + (1. - t_expand) * z0
这里 xt=t⋅x1+(1−t)⋅z0x_t = t \cdot x_1 + (1 - t) \cdot z_0xt=t⋅x1+(1−t)⋅z0,其中 x1x_1x1 是目标数据,z0z_0z0 是初始样本。
(4) 模型输出 v(xt,t)v(x_t, t)v(xt,t)
通过模型计算速度场
score
:model_fn = mutils.get_model_fn(model, train=train) score = model_fn(perturbed_data, t * 999)
model_fn
是模型的前向传播函数。t * 999
是一个时间缩放因子(具体值可以根据需要调整)。
3. 损失计算
在 Reflow 方法中,损失计算的核心是计算模型预测的速度场与目标速度场之间的差异。以下是具体步骤:
(1) 目标值 target\text{target}target
计算目标值
target
:target = batch - z0
这里 target=x1−z0\text{target} = x_1 - z_0target=x1−z0,表示从 z0z_0z0 到 x1x_1x1 的方向。
(2) 损失函数
根据
sde.reflow_loss
计算损失:如果
sde.reflow_loss == 'l2'
,则使用 L2 损失:losses = torch.square(score - target)
如果
sde.reflow_loss == 'lpips'
,则使用 LPIPS 损失(需要sde.reflow_t_schedule == 't0'
):losses = sde.lpips_model(z0 + score, batch)
如果
sde.reflow_loss == 'lpips+l2'
,则同时使用 LPIPS 损失和 L2 损失:lpips_losses = sde.lpips_model(z0 + score, batch).view(batch.shape[0], 1) l2_losses = torch.square(score - target).view(batch.shape[0], -1).mean(dim=1, keepdim=True) losses = lpips_losses + l2_losses
(3) 损失聚合
根据
reduce_mean
决定是对损失取均值还是求和:losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) loss = torch.mean(losses)
4. 总结
在使用 Reflow 方法时,前向传播和损失计算的核心逻辑如下:
前向传播:
从输入数据中提取初始样本 z0z_0z0 和目标数据 x1x_1x1。
采样时间 ttt,并计算扰动数据 xt=t⋅x1+(1−t)⋅z0x_t = t \cdot x_1 + (1 - t) \cdot z_0xt=t⋅x1+(1−t)⋅z0。
通过模型计算速度场 v(xt,t)v(x_t, t)v(xt,t)。
损失计算:
计算目标值 target=x1−z0\text{target} = x_1 - z_0target=x1−z0。
根据
sde.reflow_loss
计算损失(如 L2 损失、LPIPS 损失或两者的组合)。对损失进行聚合(取均值或求和)。
通过这种方式,Reflow 方法可以学习一个速度场,使得从初始分布 z0z_0z0 到目标分布 x1x_1x1 的路径可以通过 ODE 求解器生成。