一步扩散,分布匹配蒸馏
paper是Massachusetts Institute of Technology发表在CVPR 2024的工作
paper title:One-step Diffusion with Distribution Matching Distillation
Code:链接
图1。是哪个?在这些图像中,有些是用基线稳定扩散(SD)[63](每个2590ms)生成的,另一些则具有我们的扩散匹配蒸馏(DMD)(每个90ms)。你能说哪个是什么?我们的(从左到右):底部,顶部,底部,底部,顶部。)我们的一步文本到图像生成提供了质量可与之媲美的昂贵扩散模型。
Abstract
扩散模型可以生成高质量图像,但通常需要数十次前向传播。我们提出了分布匹配蒸馏(Distribution Matching Distillation,DMD)方法,这是一种将扩散模型转换为一步图像生成器的过程,同时尽可能减少对图像质量的影响。我们强制一步图像生成器在分布层面上匹配扩散模型,通过最小化近似 KL 散度,其梯度可以表示为两个得分函数之差,一个来自目标分布,另一个来自我们的单步生成器所产生的合成分布。这些得分函数由两个分别在每个分布上单独训练的扩散模型参数化。结合简单的回归损失来匹配多步扩散输出的全局结构,我们的方法优于所有已发表的少步扩散方法,在 ImageNet 64×64 数据集上达到了 2.62 FID,在零样本 COCO-30k 数据集上达到了 11.49 FID,性能与 Stable Diffusion 相当,但速度提升了数个数量级。利用 FP16 进行推理,我们的模型可以在现代硬件上以 20 FPS 生成图像。
1. Introduction
扩散模型 [21, 61, 63, 64, 71, 74] 已经彻底改变了图像生成领域,实现了前所未有的真实感和多样性,同时具有稳定的训练过程。然而,与 GANs [15] 和 VAEs [34] 相比,它们的采样过程较慢,是一个迭代过程,通过逐步去噪 [21, 74] 将高斯噪声样本转换为复杂的图像。这通常需要数十到数百次昂贵的神经网络计算,这限制了在创意工具中使用生成管道的交互性。
为了加速采样速度,以往的方法 [42, 43, 47, 48, 51, 65, 75, 90, 91] 通过蒸馏原始多步扩散采样过程中发现的噪声到图像的映射,将其压缩为单次前向传播的学生网络。然而,拟合如此高维度、复杂的映射无疑是一项极具挑战的任务。一个主要挑战是运行完整的去噪轨迹的高昂计算成本,仅仅是为了计算学生模型的一次损失。最近的方法通过逐步增加学生模型的采样步长,而不运行原始扩散模型的完整去噪序列来缓解这一问题 [3, 16, 42, 43, 51, 65, 75]。然而,蒸馏后的模型在性能上仍然落后于原始的多步扩散模型。
相比之下,我们并不强制建立噪声与扩散生成图像之间的对应关系,而是仅要求学生模型生成的图像在感知上与原始扩散模型生成的图像无法区分。从高层次来看,我们的目标与其他分布匹配生成模型(如 GMMN [39] 或 GANs [15])的动机相似。然而,尽管这些方法在生成逼真图像方面取得了显著成功 [27, 30],但在大规模文本到图像数据上的扩展仍然具有挑战性 [26, 62, 87]。在本研究中,我们绕过了这一问题,而是从一个已经在大规模文本到图像数据上训练过的扩散模型开始。具体而言,我们对预训练的扩散模型进行微调,使其不仅学习数据分布,还学习由我们的蒸馏生成器所产生的伪造分布。由于扩散模型被认为能够近似扩散分布上的分数函数 [23, 73],我们可以将去噪后的扩散输出解释为指向“更真实”图像的梯度方向,或者如果扩散模型是在伪造图像上训练的,则指向“更伪造”图像的梯度方向。最终,生成器的梯度更新规则被设定为二者之间的差异,将合成图像推向更高的真实度和更低的伪造度。先前的研究 [80] 提出了变分分数蒸馏(Variational Score Distillation)方法,表明使用预训练的扩散模型对真实和伪造分布进行建模在 3D 物体的测试时优化(test-time optimization)方面也同样有效。我们的洞见是,这种类似的方法实际上可以用来训练整个生成模型。
此外,我们发现预计算一定数量的多步扩散采样结果,并针对我们的单步生成施加一个简单的回归损失,在分布匹配损失的约束下,能够作为一种有效的正则化方法。此外,回归损失确保我们的单步生成器与教师模型保持一致(见图 6),展现出实时设计预览的潜力。我们的方法从 VSD [80]、GANs [15] 和 pix2pix [24] 中汲取灵感和见解,表明通过 (1) 使用扩散模型对真实和伪造分布进行建模,以及 (2) 使用简单的回归损失来匹配多步扩散输出,可以训练出一个高保真度的单步生成模型。
我们使用分布匹配蒸馏(DMD)方法训练的模型,在多个任务上进行了评估,包括 CIFAR-10 [36] 和 ImageNet 64×64 [8] 上的图像生成,以及 MS COCO 512×512 [40] 上的零样本文本到图像生成。在所有基准测试中,我们的单步生成器显著优于所有已发布的少步扩散方法,如 Progressive Distillation [51, 65]、Rectified Flow [42, 43] 和 Consistency Models [48, 75]。在 ImageNet 上,DMD 实现了 2.62 的 FID,相较于 Consistency Model [75] 提升了 2.4 倍。采用与 Stable Diffusion [63] 相同的去噪器架构,DMD 在 MS-COCO 2014 30k 上达到了 11.49 的 FID,竞争力强劲。我们的定量和定性评估表明,我们的模型生成的图像在质量上与计算成本高昂的 Stable Diffusion 生成的图像极为相似。更重要的是,我们的方法在保持图像保真度的同时,将神经网络评估次数减少了 100 倍。这种高效性使得 DMD 在使用 FP16 推理时,能够以 20 FPS 的速率生成 512×512 图像,为交互式应用带来了广阔的可能性。
2. Related Work
扩散模型扩散模型[2,21,71,74]已成为一个强大的生成建模框架,在图像生成等不同领域中取得了无与伦比的成功[61,63,64],音频合成[6,35]和视频一代[11,22,70]。这些模型通过通过反向扩散过程将噪声逐渐转化为相干结构来运行[72,74]。尽管结果是最新的结果,但扩散模型的固有迭代过程仍需要实时应用的较高且通常是高昂的计算成本。我们的工作基于领先的扩散模型[31,63],并引入了简单的蒸馏管道,将多步生成过程减少到单个正向通行证。我们的方法普遍适用于具有确定性抽样的任何扩散模型[31,72,74]。
扩散加速加速扩散模型的推理过程一直是该领域的核心关注点,促使了两类方法的发展。第一类方法致力于改进快速扩散采样器 [31, 41, 45, 46, 90],显著减少预训练扩散模型所需的采样步数——从上千步降至仅 20-50 步。然而,进一步减少步数通常会导致性能的灾难性下降。另一种方法是扩散蒸馏,它作为提升推理速度的一种有前景的途径 [3, 16, 42, 47, 51, 65, 75, 82, 91]。这些方法将扩散蒸馏建模为知识蒸馏 [19],即训练一个学生模型,将原始扩散模型的多步输出蒸馏到单步生成。Luhman 等人 [47] 以及 DSNO [92] 提出了一种简单的方法,即预计算去噪轨迹,并使用像素空间的回归损失来训练学生模型。
然而,一个重要的挑战是,在每次损失计算时运行完整的去噪轨迹成本极其昂贵。为了解决这一问题,Progressive Distillation (PD) [51, 65] 训练一系列学生模型,每个学生模型的采样步数都比前一个模型减少一半。InstaFlow [42, 43] 逐步学习更平滑的流,使得单步预测在更长距离上保持准确性。Consistency Distillation (CD) [75]、TRACT [3] 和 BOOT [16] 训练一个学生模型,使其在 ODE 流上的不同时间步匹配自身输出,并进一步约束其输出在另一个时间步上保持一致。相比之下,我们的方法表明,Luhman 等人和 DSNO 提出的简单预计算扩散输出的方法已经足够有效,只需在训练目标中引入分布匹配即可。
分布匹配近年来,一些生成模型通过恢复受预定义机制(如噪声注入 [21, 61, 64] 或 token 掩码 [5, 60, 86])损坏的样本,在扩展到复杂数据集方面取得了成功。另一方面,存在一些生成方法并不依赖样本重构作为训练目标,而是直接在分布层面对合成样本和目标样本进行匹配,例如 GMMD [10, 39] 或 GANs [15]。其中,GANs 在生成逼真图像方面展现出了前所未有的质量 [4, 26–28, 30, 67],特别是在 GAN 损失可以与特定任务的辅助回归损失结合使用以缓解训练不稳定性时,其应用范围从配对图像转换 [24, 54, 79, 89] 到无配对图像编辑 [37, 55, 94]。尽管如此,由于在大规模训练时需要精心设计网络架构以确保训练稳定性,GANs 在文本引导的合成任务中仍然不是主流选择 [26]。
最近,一些研究 [1, 12, 85] 发现了基于分数的模型与分布匹配之间的联系。特别是,ProlificDreamer [80] 提出了变分分数蒸馏(Variational Score Distillation, VSD),该方法利用预训练的文本到图像扩散模型作为分布匹配损失。由于 VSD 能够利用大规模预训练模型进行无配对学习 [17, 58],它在基于粒子的优化方法中取得了令人印象深刻的结果,特别是在文本引导的 3D 生成任务中。我们的方法对 VSD 进行了改进和扩展,使其适用于深度生成神经网络的训练,以蒸馏扩散模型。此外,受 GAN 在图像转换任务中取得成功的启发,我们引入了回归损失来增强训练的稳定性。因此,我们的方法能够在 LAION [69] 等复杂数据集上实现高质量的真实感生成。与最近将 GAN 与扩散模型结合的方法 [68, 81–83] 不同,我们的方法并不依赖于 GAN 框架。我们的方法与一些同时进行的研究 [50, 84] 共享相似的动机,即利用 VSD 目标训练生成器,但不同之处在于,我们通过引入回归损失专门优化了扩散蒸馏,并在文本到图像任务上取得了最先进的结果。
3. Distribution Matching Distillation
我们的目标是将一个给定的预训练扩散去噪器,即 基础模型(base model) μ base \mu_{\text{base}} μbase,蒸馏为一个快速的 “一步” 图像生成器 G θ G_{\theta} Gθ,使其能够生成高质量图像,而无需昂贵的迭代采样过程(参见 Sec. 3.1)。尽管我们希望从相同的分布中生成样本,但并不一定要精确复现原始映射。
类似于 GAN,我们将蒸馏模型的输出称为 fake,与训练数据分布中的 real 图像相对。我们在 图 2 中展示了我们的方法。我们通过最小化两个损失的总和来训练快速生成器:
- 分布匹配目标(distribution matching objective)(参见 Sec. 3.2),其梯度更新可以表示为两个得分函数(score function)之差;
- 回归损失(regression loss)(参见 Sec. 3.3),鼓励生成器在一个固定的噪声-图像数据集上匹配基础模型输出的大尺度结构。
至关重要的是,我们使用两个扩散去噪器分别对 真实分布(real) 和 伪造分布(fake) 的得分函数进行建模,并施加不同幅度的高斯噪声扰动。最后,在 Sec. 3.4 中,我们展示如何通过 无分类器引导(classifier-free guidance) 适配我们的训练过程。
方法概述(Method overview) 我们训练一步生成器 G θ G_{\theta} Gθ 以将随机噪声 z z z 映射到真实图像。为了匹配扩散模型的多步采样输出,我们预先计算了一组噪声-图像对,并偶尔从该集合中加载噪声,同时施加 LPIPS [88] 回归损失(regression loss) 以使一步生成器的输出与扩散模型的输出一致。此外,我们提供 分布匹配梯度(distribution matching gradient) ∇ θ D K L \nabla_{\theta} D_{KL} ∇θDKL 以增强生成的真实感。我们向伪造图像添加随机噪声,并将其输入到两个扩散模型中:一个在真实数据上预训练,另一个则持续在伪造图像上训练,采用 扩散损失(diffusion loss) 进行去噪。去噪得分(在图中可视化为均值预测)指示了如何调整图像,使其更真实或更虚假。两者之间的差异表示朝向更真实和更少伪造的方向,并将其反向传播到一步生成器。
3.1. Pretrained base model and One-step generator
我们的蒸馏过程假设给定了一个预训练的扩散模型 μ base \mu_{\text{base}} μbase。扩散模型被训练以逆转高斯扩散过程,该过程逐步向样本添加噪声,使其从真实数据分布 x 0 ∼ p real x_0 \sim p_{\text{real}} x0∼preal 变为 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xT∼N(0,I),经过 T T T 个时间步 [21, 71, 74],我们使用 T = 1000 T = 1000 T=1000。我们用 μ base ( x t , t ) \mu_{\text{base}}(x_t, t) μbase(xt,t) 表示扩散模型。从高斯样本 x T x_T xT 开始,模型在时间步 t ∈ { 0 , 1 , . . . , T − 1 } t \in \{0, 1, ..., T-1\} t∈{0,1,...,T−1}(或噪声级别)上迭代去噪一个运行中的噪声估计 x t x_t xt,从而生成目标数据分布的样本。扩散模型通常需要 10 到 100 个步长才能生成逼真的图像。
我们的推导使用扩散的均值预测(mean-prediction)来简化 [31],但也能在 ϵ \epsilon ϵ 预测( ϵ \epsilon ϵ-prediction)框架下工作 [21, 63],仅需变量变换 [33](详见附录 H)。我们的实现使用来自 EDM [31] 和 Stable Diffusion [63] 的预训练模型。
一步生成器(One-step generator) 我们的一步生成器 G θ G_{\theta} Gθ 具有与基础扩散去噪器相同的架构,但去除了时间条件。我们在训练前初始化其参数 θ \theta θ 为基础模型的参数,即:
G θ ( z ) = μ base ( z , T − 1 ) , ∀ z G_{\theta}(z) = \mu_{\text{base}}(z, T-1), \forall z Gθ(z)=μbase(z,T−1),∀z
3.2. Distribution Matching Loss
理想情况下,我们希望快速生成器能够生成与真实图像无法区分的样本。受 ProlificDreamer [80] 启发,我们最小化真实和伪图像分布之间的 Kullback–Leibler(KL)散度 D K L D_{KL} DKL,分别表示为 p real p_{\text{real}} preal 和 p fake p_{\text{fake}} pfake:
D K L ( p fake ∥ p real ) = E x ∼ p fake ( log ( p fakc ( x ) p real ( x ) ) ) = E z ∼ N ( 0 , I ) x = G θ ( z ) − ( log p real ( x ) − log p fakc ( x ) ) \begin{aligned} D_{K L}\left(p_{\text {fake }} \| p_{\text {real }}\right) & =\underset{x \sim p_{\text {fake }}}{\mathbb{E}}\left(\log \left(\frac{p_{\text {fakc }}(x)}{p_{\text {real }}(x)}\right)\right) \\ & =\underset{\substack{z \sim \mathcal{N}(0, I) \\ x=G_\theta(z)}}{\mathbb{E}}-\left(\log p_{\text {real }}(x)-\log p_{\text {fakc }}(x)\right) \end{aligned} DKL(pfake ∥preal )=x∼pfake E(log(preal (x)pfakc (x)))=z∼N(0,I)x=Gθ(z)E−(logpreal (x)−logpfakc (x))
(1)
计算概率密度以估计此损失通常是不可行的,但我们只需要对 θ \theta θ 计算梯度,以通过梯度下降训练生成器。
使用近似分数进行梯度更新(Gradient update using approximate scores)
对方程(1)对生成器参数求梯度:
∇ θ D K L = E z ∼ N ( 0 , I ) x = G θ ( z ) [ − ( s real ( x ) − s fake ( x ) ) d G d θ ] \nabla_\theta D_{K L}=\underset{\substack{z \sim \mathcal{N}(0, \mathbf{I}) \\ x=G_\theta(z)}}{\mathbb{E}}\left[-\left(s_{\text {real }}(x)-s_{\text {fake }}(x)\right) \frac{d G}{d \theta}\right] ∇θDKL=z∼N(0,I)x=Gθ(z)E[−(sreal (x)−sfake (x))dθdG]
(2)
其中 s real ( x ) = ∇ x log p real ( x ) s_{\text{real}}(x) = \nabla_x \log p_{\text{real}}(x) sreal(x)=∇xlogpreal(x), s fake ( x ) = ∇ x log p fake ( x ) s_{\text{fake}}(x) = \nabla_x \log p_{\text{fake}}(x) sfake(x)=∇xlogpfake(x) 分别是两个分布的分数函数。直观上, s real ( x ) s_{\text{real}}(x) sreal(x) 使 x x x 朝向 p real p_{\text{real}} preal 的模式,而 − s fake -s_{\text{fake}} −sfake 使其分散,如图 3(a, b) 所示。
图3。优化从相同配置(左)开始的各种目标会导致不同的结果。 (a)仅最大化真实分数,假样品都崩溃到了真实分布的最接近模式。 (b)有了我们的分配匹配目标而不是回归损失,生成的假数据涵盖了更多的真实分布,但仅恢复最接近的模式,完全缺少第二种模式。 (c)我们的完整目标,通过回归损失,恢复了目标分布的两种模式。
计算此梯度仍然具有挑战性,主要有两个原因:
首先,分数对于低概率的样本来说是不稳定的,特别是 p real p_{\text{real}} preal 的尾部区域;
其次,我们用于分数估计的扩散模型只能提供关于其训练分布的分数信息。Score-SDE [73, 74] 提供了一种解决这些问题的方法。
通过向数据分布加入具有不同标准差的随机高斯噪声,我们创建了一系列“模糊”的分布,这些分布在整个空间中是完全支持的,并因此相互重叠,使得方程(2)中的梯度是良好定义的(图 4)。Score-SDE 证明了经过训练的扩散模型可以近似扩散分布的分数函数。
图4。如果没有扰动,实际/假发行版可能不会重叠(a)。 Real样本只能从真实分数中获得有效的梯度,而假分数的假样本。扩散(b)后,我们的分布匹配目标到处都有明确的定义。
因此,我们的策略是使用一对扩散去噪器来建模高斯扩散后真实分布和伪分布的分数。为简化符号,我们分别定义这些分数为 s real ( x t , t ) s_{\text{real}}(x_t,t) sreal(xt,t) 和 s fake ( x t , t ) s_{\text{fake}}(x_t,t) sfake(xt,t)。扩散后的样本 x t ∼ q ( x t ∣ x ) x_t \sim q(x_t|x) xt∼q(xt∣x) 通过在生成器输出 x = G θ ( z ) x=G_{\theta}(z) x=Gθ(z) 上添加噪声获得,对应的扩散过程为:
q t ( x t ∣ x ) ∼ N ( α t x ; σ t 2 I ) , q_t(x_t|x) \sim \mathcal{N}(\alpha_t x; \sigma_t^2 I), qt(xt∣x)∼N(αtx;σt2I),
(3)
其中 α t \alpha_t αt 和 σ t \sigma_t σt 来自扩散噪声调度。
真实分数(Real score)
真实分布是固定的,对应于基础扩散模型的训练图像,因此我们使用预训练扩散模型 μ base ( x t , t ) \mu_{\text{base}}(x_t,t) μbase(xt,t) 的冻结副本来建模其分数。根据 Song 等人 [74],扩散模型的分数计算如下:
s real ( x t , t ) = − x t − α t μ base ( x t , t ) σ t 2 . s_{\text{real}}(x_t, t) = \frac{- x_t - \alpha_t \mu_{\text{base}}(x_t, t)}{\sigma_t^2}. sreal(xt,t)=σt2−xt−αtμbase(xt,t).
(4)
动态学习的伪分数(Dynamically-learned fake score)
我们按照真实分数的计算方式推导伪分数函数:
s fake ( x t , t ) = − x t − α t μ fake ϕ ( x t , t ) σ t 2 . s_{\text{fake}}(x_t, t) = \frac{- x_t - \alpha_t \mu_{\text{fake}}^{\phi}(x_t, t)}{\sigma_t^2}. sfake(xt,t)=σt2−xt−αtμfakeϕ(xt,t).
(5)
然而,由于生成样本的分布在训练过程中不断变化,我们动态调整伪扩散模型 μ fake ϕ \mu_{\text{fake}}^{\phi} μfakeϕ 以跟踪这些变化。我们从预训练扩散模型 μ base \mu_{\text{base}} μbase 初始化伪扩散模型,并在训练过程中更新参数 ϕ \phi ϕ,通过最小化标准去噪目标 [21, 77] 来进行优化:
L denoise ϕ = ∣ ∣ μ fake ϕ ( x t , t ) − x 0 ∣ ∣ 2 2 , \mathcal{L}_{\text{denoise}}^{\phi} = ||\mu_{\text{fake}}^{\phi}(x_t, t) - x_0||_2^2, Ldenoiseϕ=∣∣μfakeϕ(xt,t)−x0∣∣22,
(6)
其中 L denoise ϕ \mathcal{L}_{\text{denoise}}^{\phi} Ldenoiseϕ 按照扩散时间步 t t t 进行加权,采用与基础扩散模型训练相同的加权策略 [31, 63]。
分布匹配梯度更新(Distribution matching gradient update)
我们的最终近似分布匹配梯度是通过在方程(2)中的分数项替换为由两个扩散模型在扰动样本 x t x_t xt 上的输出所得,并在扩散时间步上计算期望得到的:
∇ θ D K L ≈ E z , t , x t , x [ w t α t ( s fake ( x t , t ) − s real ( x t , t ) ) d G θ d θ ] , \nabla_{\theta} D_{KL} \approx \mathbb{E}_{z,t,x_t,x} \left[ w_t \alpha_t (s_{\text{fake}}(x_t, t) - s_{\text{real}}(x_t, t)) \frac{d G_{\theta}}{d \theta} \right], ∇θDKL≈Ez,t,xt,x[wtαt(sfake(xt,t)−sreal(xt,t))dθdGθ],
(7)
其中 z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) z∼N(0,I), x = G θ ( z ) x = G_{\theta}(z) x=Gθ(z), t ∼ U ( T min , T max ) t \sim \mathcal{U}(T_{\text{min}}, T_{\text{max}}) t∼U(Tmin,Tmax), x t ∼ q t ( x t ∣ x ) x_t \sim q_t(x_t | x) xt∼qt(xt∣x)。推导细节见附录 F。
这里, w t w_t wt 是一个时间相关的缩放权重项,我们在整个训练过程中动态调整其值。我们设计加权因子以在不同的噪声水平下对梯度幅度进行归一化。具体来说,我们计算去噪图像与输入之间在空间和通道维度上的均值绝对误差,并设定:
w t = σ t 2 α t C S ∣ ∣ μ base ( x t , t ) − x t ∣ ∣ , w_t = \frac{\sigma_t^2}{\alpha_t} \frac{CS}{||\mu_{\text{base}}(x_t, t) - x_t||}, wt=αtσt2∣∣μbase(xt,t)−xt∣∣CS,
(8)
其中 S S S 是空间位置的数量, C C C 是通道数量。在 Sec. 4.2 中,我们表明这种权重方案优于之前的方法 [58, 80]。我们设定 T min = 0.02 T T_{\text{min}} = 0.02T Tmin=0.02T, T max = 0.98 T T_{\text{max}} = 0.98T Tmax=0.98T,遵循 DreamFusion [58] 的做法。
3.3. Regression loss and final objective
分布匹配目标在 t ≫ 0 t \gg 0 t≫0 时是良定义的,即当生成的样本受到大量噪声污染时。然而,对于少量的噪声, s real ( x t , t ) s_{\text{real}}(x_t, t) sreal(xt,t) 往往变得不可靠,因为 p real ( x t , t ) p_{\text{real}}(x_t, t) preal(xt,t) 逐渐趋于零。此外,由于梯度 ∇ x log ( p ) \nabla_x \log(p) ∇xlog(p) 对概率密度函数 p p p 的缩放不变,优化过程可能导致模式坍缩(collapse/dropping),即伪扩散分布可能会对数据子集分配更高的整体密度。为了避免这种情况,我们引入了额外的回归损失,以确保所有模式都被保留;参见图 3(b)、©。
该损失测量了生成器与基础扩散模型输出之间的逐点距离,在相同的输入噪声下。具体而言,我们构建了一个成对数据集 D = { z , y } \mathcal{D} = \{z, y\} D={z,y},其中 z z z 是随机高斯噪声图像,而 y y y 是使用确定性 ODE 求解器 [31, 41, 72] 对预训练扩散模型 μ base \mu_{\text{base}} μbase 进行采样得到的对应输出。
在 CIFAR-10 和 ImageNet 任务中,我们采用 EDM [31] 的 Heun 求解器,其中 CIFAR-10 采用 18 步采样,ImageNet 采用 256 步采样。对于 LAION 任务,我们使用 PNDM [41],采用 50 步采样。我们发现,即使是较少量的噪声-图像对,仅占训练计算量的 1% 以下,在 CIFAR-10 任务中仍能作为有效的正则化方法。
我们的回归损失定义如下:
L reg = E ( z , y ) ∼ D ℓ ( G θ ( z ) , y ) . \mathcal{L}_{\text{reg}} = \mathbb{E}_{(z,y) \sim \mathcal{D}} \ell(G_{\theta}(z), y). Lreg=E(z,y)∼Dℓ(Gθ(z),y).
(9)
我们使用 Learned Perceptual Image Patch Similarity (LPIPS) [88] 作为距离函数 ℓ \ell ℓ,类似于 InstaFlow [43] 和 Consistency Models [75]。
最终目标(Final objective)。网络 μ fake ϕ \mu_{\text{fake}}^{\phi} μfakeϕ 通过 L denoise ϕ \mathcal{L}_{\text{denoise}}^{\phi} Ldenoiseϕ 进行训练,该损失用于计算 ∇ θ D K L \nabla_{\theta} D_{KL} ∇θDKL。在训练 G θ G_{\theta} Gθ 时,最终目标为
D K L + λ reg L reg , D_{KL} + \lambda_{\text{reg}} \mathcal{L}_{\text{reg}}, DKL+λregLreg,
其中 λ reg = 0.25 \lambda_{\text{reg}} = 0.25 λreg=0.25,除非另有说明。梯度 ∇ θ D K L \nabla_{\theta} D_{KL} ∇θDKL 由公式 (7) 计算,梯度 ∇ θ L reg \nabla_{\theta} \mathcal{L}_{\text{reg}} ∇θLreg 由公式 (9) 通过自动微分计算。
我们将这两个损失应用于不同的数据流:未配对的假样本用于计算分布匹配梯度,而配对样本用于回归损失,具体描述参见 第 3.3 节。算法 1 概述了最终的训练流程。更多详细信息请参见 附录 B。
3.4. Distillation with classifier-free guidance
无分类器引导 [20] 被广泛用于提升文本到图像扩散模型的图像质量。我们的方法同样适用于使用无分类器引导的扩散模型。我们首先通过从引导模型中采样来生成相应的噪声-输出对,以构建回归损失 L reg L_{\text{reg}} Lreg 需要的配对数据集。在计算分布匹配梯度 ∇ θ D K L \nabla_{\theta} D_{KL} ∇θDKL 时,我们用引导模型的均值预测所导出的真实分数替换原始的真实分数。同时,我们不修改假分数的公式。我们使用固定的引导尺度训练我们的单步生成器。