I recently went through some material on diffusion models, mainly the Bilibili video 《2024了,diffusion为什么一直没对手?迪哥详解diffusion扩散模型直观理解、数学原理、PyTorch实现 超越GANs的范式转变!》. This post records my own understanding.
1. What does the diffusion model do?
Here $x_0$ is the original, noise-free image. $x_1,\cdots,x_t,\cdots,x_n$ is a sequence of images obtained by repeatedly adding noise to $x_0$. $z_t$ denotes the noise, which follows a standard normal distribution, $z_t \sim N(0,1)$. The noising formula is
\begin{equation} x_t=\sqrt{\alpha_t}\cdot x_{t-1}+\sqrt{\beta_t}\cdot z_t, \qquad \beta_t=1-\alpha_t \end{equation}
$\alpha_t$ and $\beta_t$ are the weights of the original image and of the noise, respectively. In the paper, $\beta_t$ is scheduled linearly from 0.0001 to 0.02 over the $n$ steps, so $\beta_t$ grows with $t$: the weight of the noise keeps increasing. The intuition is that at the beginning the image is still clean, so even a small amount of noise is clearly visible; later the image is already blurry, and more noise must be added before the change becomes noticeable. As for why the formula uses $\sqrt{\alpha_t}$ and $\sqrt{\beta_t}$, my personal feeling is that it is mainly for algebraic convenience; note also that since $\alpha_t+\beta_t=1$, this choice keeps the variance of $x_t$ from blowing up.
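To make Equation 1 concrete, here is a minimal sketch of the step-by-step forward process (the helper forward_step and the random stand-in image are mine for illustration; the schedule matches the linear one described above):

import torch

def forward_step(x_prev, alpha_t):
    # Equation 1: x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * z_t
    z_t = torch.randn_like(x_prev)              # z_t ~ N(0, 1)
    return alpha_t.sqrt() * x_prev + (1 - alpha_t).sqrt() * z_t

betas = torch.linspace(1e-4, 0.02, 200)         # linear beta schedule from the paper
alphas = 1 - betas
x = torch.randn(1, 3, 32, 32)                   # stand-in for a real x_0
for t in range(200):
    x = forward_step(x, alphas[t])              # after all steps, x is close to pure noise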
Recall that the sum of two independent Gaussian variables $N(\mu_1,\sigma_1^2)$ and $N(\mu_2,\sigma_2^2)$ is again Gaussian: $N(\mu_1,\sigma_1^2)+N(\mu_2,\sigma_2^2)\sim N(\mu_1+\mu_2,\ \sigma_1^2+\sigma_2^2)$.
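A quick empirical check of this rule (illustrative only; the numbers are arbitrary):

import torch

# N(1, 2^2) + N(-3, 1.5^2) should give N(-2, 2^2 + 1.5^2) = N(-2, 6.25)
a = 1 + 2.0 * torch.randn(1_000_000)
b = -3 + 1.5 * torch.randn(1_000_000)
s = a + b
print(s.mean().item(), s.var().item())   # roughly -2.0 and 6.25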
In the forward pass, $x_t$ is computed from $x_{t-1}$, which in turn is computed from $x_{t-2}$, and so on. Can $x_t$ be computed directly from $x_0$? The derivation follows.
\begin{equation} \begin{split} x_t &=\sqrt{\alpha_t}\cdot x_{t-1}+\sqrt{1-\alpha_t}\cdot z_t \\ &= \sqrt{\alpha_t}\cdot( \sqrt{\alpha_{t-1}}\cdot x_{t-2}+\sqrt{1-\alpha_{t-1}}\cdot z_{t-1}) +\sqrt{1-\alpha_t}\cdot z_t \\ &= \sqrt{\alpha_t \cdot \alpha_{t-1}} \cdot x_{t-2}+ \sqrt{\alpha_t} \cdot \sqrt{1-\alpha_{t-1}}\cdot z_{t-1} +\sqrt{1-\alpha_t}\cdot z_t \\ &= \sqrt{\alpha_t \cdot \alpha_{t-1}} \cdot x_{t-2}+\sqrt{1-\alpha_t \cdot \alpha_{t-1}} \cdot z \\ &\;\;\vdots\\ &= \sqrt{\alpha_t \cdot \alpha_{t-1}\cdots \alpha_{1}} \cdot x_{0} +\sqrt{1-\alpha_t \cdot \alpha_{t-1}\cdots \alpha_{1}}\cdot z\\ &= \sqrt{\bar{\alpha_t}} \cdot x_{0} +\sqrt{1-\bar{\alpha_t}}\cdot z \qquad (\text{letting } \bar{\alpha_t}= \alpha_t \cdot \alpha_{t-1}\cdots \alpha_{1}) \end{split} \end{equation}
Here $z_t$ and $z_{t-1}$ are the noises added at steps $t$ and $t-1$; they are independent draws from $N(0,1)$. By the sum rule above, the last two terms in line 3 can therefore be merged into a single Gaussian, giving line 4: the variances add, $\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1-\alpha_t\cdot\alpha_{t-1}$. The merged variable is still a zero-mean Gaussian (only its variance changes) and is written $z$. The same merging is repeated at every subsequent step.
Originally, $x_t$ had to be produced by adding noise step by step. Equation 2 shows that $x_t$ can instead be obtained in a single step from $x_0$, a standard normal noise $z$, and the noise schedule $\alpha_t, \alpha_{t-1}\cdots \alpha_{1}$. It also shows that $x_t\sim N(\sqrt{\bar{\alpha_t}} \cdot x_{0},\ 1-\bar{\alpha_t})$.
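This one-shot formula is exactly what q_sample implements in the full code at the end of this post; a self-contained sketch:

import torch

betas = torch.linspace(1e-4, 0.02, 200)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)     # alpha_bar_t for every t

def q_sample(x0, t, noise=None):
    # Equation 2: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * z
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t]
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

x0 = torch.randn(1, 3, 32, 32)                       # stand-in for a real image
x150 = q_sample(x0, t=150)    # same distribution as 150 sequential noising steps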
The forward noising process (going from $x_{t-1}$ to $x_t$, written $p(x_t|x_{t-1})$) is easy to understand, because we defined it ourselves. Now consider the reverse denoising process: going from $x_t$ back to $x_{t-1}$, written $p(x_{t-1}|x_t)$.
By Bayes' rule,
\begin{equation} p(x_{t-1}|x_t,x_0) = p(x_t|x_{t-1},x_0)\cdot \frac{p(x_{t-1}|x_0)}{p(x_t|x_0)} \end{equation}
First consider the factor $p(x_t|x_{t-1},x_0)$.
From Equation 1, $x_t=\sqrt{\alpha_t}\cdot x_{t-1}+\sqrt{1- \alpha_t}\cdot z_t$, so conditioned on $x_{t-1}$ we have $x_t\sim N(\mu=\sqrt{\alpha_{t}} \cdot x_{t-1},\ \sigma^2=1-\alpha_t)$. Therefore $p(x_t|x_{t-1},x_0)$ can be written as
\begin{equation} \begin{split} p(x_t|x_{t-1},x_0) &= \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{(x_t-\mu)^2}{2\sigma^2}} \quad(\text{Gaussian density})\\ &= \frac{1}{\sqrt{2\pi}\sqrt{1-\alpha_t}} \cdot e^{-\frac{(x_t-\sqrt{\alpha_{t}}x_{t-1})^2}{2 (1-\alpha_t)}} \quad(\text{substituting } \mu \text{ and } \sigma) \end{split} \end{equation}
Next consider $p(x_{t-1}|x_0)$ in Equation 3.
From $x_{t-1}=\sqrt{\bar{\alpha_{t-1}}} \cdot x_{0} +\sqrt{1-\bar{\alpha_{t-1}}}\cdot z$ (Equation 2 applied up to step $t-1$), we have $x_{t-1}\sim N(\mu=\sqrt{\bar{\alpha_{t-1}}} \cdot x_{0},\ \sigma^2={1-\bar{\alpha_{t-1}}})$, so $p(x_{t-1}|x_0)$ can be written as
\begin{equation} \begin{split} p(x_{t-1}|x_0) &= \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{(x_{t-1}-\mu)^2}{2\sigma^2}} \quad(\text{Gaussian density})\\ &= \frac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha_{t-1}}}} \cdot e^{-\frac{(x_{t-1}-\sqrt{\bar{\alpha_{t-1}}} \cdot x_{0})^2}{2 (1-\bar{\alpha_{t-1}})}} \quad(\text{substituting } \mu \text{ and } \sigma) \end{split} \end{equation}
Finally consider $p(x_t|x_0)$ in Equation 3.
From $x_{t}=\sqrt{\bar{\alpha_{t}}} \cdot x_{0} +\sqrt{1-\bar{\alpha_{t}}}\cdot z$ (Equation 2), we have $x_{t}\sim N(\mu=\sqrt{\bar{\alpha_{t}}} \cdot x_{0},\ \sigma^2={1-\bar{\alpha_{t}}})$, so $p(x_t|x_0)$ can be written as
\begin{equation} \begin{split} p(x_{t}|x_0) &= \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{(x_{t}-\mu)^2}{2\sigma^2}} \quad(\text{Gaussian density})\\ &= \frac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha_{t}}}} \cdot e^{-\frac{(x_{t}-\sqrt{\bar{\alpha_{t}}} \cdot x_{0})^2}{2 (1-\bar{\alpha_{t}})}} \quad(\text{substituting } \mu \text{ and } \sigma) \end{split} \end{equation}
Substituting Equations 4, 5 and 6 into Equation 3 gives
\begin{equation} \begin{split} p(x_t|x_{t-1},x_0)\cdot \frac{p(x_{t-1}|x_0)}{p(x_t|x_0)} &= \frac{1}{\sqrt{2\pi}\sqrt{1-\alpha_t}} \cdot e^{-\frac{(x_t-\sqrt{\alpha_{t}}x_{t-1})^2}{2 (1-\alpha_t)}}\cdot \frac{\frac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha_{t-1}}}} \cdot e^{-\frac{(x_{t-1}-\sqrt{\bar{\alpha_{t-1}}} \cdot x_{0})^2}{2 (1-\bar{\alpha_{t-1}})}}}{\frac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha_{t}}}} \cdot e^{-\frac{(x_{t}-\sqrt{\bar{\alpha_{t}}} \cdot x_{0})^2}{2 (1-\bar{\alpha_{t}})}}} \\ &=\frac{\sqrt{1-\bar{\alpha_{t}}}}{\sqrt{2\pi} \sqrt{1-\alpha_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}\, e^{\left(-\frac{(x_t-\sqrt{\alpha_{t}}x_{t-1})^2}{2 (1-\alpha_t)}\right)+\left(-\frac{(x_{t-1}-\sqrt{\bar{\alpha_{t-1}}} \cdot x_{0})^2}{2 (1-\bar{\alpha_{t-1}})}\right)-\left(-\frac{(x_{t}-\sqrt{\bar{\alpha_{t}}} \cdot x_{0})^2}{2 (1-\bar{\alpha_{t}})}\right)}\\ &=\frac{\sqrt{1-\bar{\alpha_{t}}}}{\sqrt{2\pi} \sqrt{\beta_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}\, e^{-\frac{1}{2}\left[\frac{(x_t-\sqrt{\alpha_{t}}x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar{\alpha_{t-1}}} \cdot x_{0})^2}{1-\bar{\alpha_{t-1}}}-\frac{(x_{t}-\sqrt{\bar{\alpha_{t}}} \cdot x_{0})^2}{1-\bar{\alpha_{t}}}\right]}\\ &=\frac{\sqrt{1-\bar{\alpha_{t}}}}{\sqrt{2\pi} \sqrt{\beta_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}\, e^{-\frac{1}{2}\left[\frac{x_{t}^{2}-2 x_t x_{t-1} \sqrt{\alpha_t}+\alpha_t x_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2 x_{t-1} x_0 \sqrt{\bar{\alpha_{t-1}}}+\bar{\alpha_{t-1}} x_0^2}{1-\bar{\alpha_{t-1}}}-\frac{x_t^2-2 x_t x_0 \sqrt{\bar{\alpha_t}}+\bar{\alpha_t} x_0^2}{1-\bar{\alpha_t}}\right]}\\ &=\frac{\sqrt{1-\bar{\alpha_{t}}}}{\sqrt{2\pi} \sqrt{\beta_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}\, e^{-\frac{1}{2}\left[\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha_{t-1}}}\right) x_{t-1}^2-\left(\frac{2 x_t \sqrt{\alpha_t}}{\beta_t}+\frac{2 x_0 \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_{t-1}}}\right) x_{t-1}+C(x_t,x_0)\right]} \end{split} \end{equation}
Now consider the left-hand side of Equation 3.
\begin{equation} \begin{split} p(x_{t-1}|x_t,x_0) &= \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{(x_{t-1}-\mu)^2}{2\sigma^2}} \quad(\text{Gaussian density})\\ &= \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{1}{2}\left[\frac{x_{t-1}^2-2 \mu x_{t-1} + \mu^2}{\sigma^2}\right]} \\ &= \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{1}{2}\left[\frac{1}{\sigma^2} x_{t-1}^2-\frac{2\mu}{\sigma^2} x_{t-1} + \frac{\mu^2}{\sigma^2}\right]} \end{split} \end{equation}
Since Equations 7 and 8 describe the same distribution, we can equate them:
\begin{equation} \frac{\sqrt{1-\bar{\alpha_{t}}}}{\sqrt{2\pi} \sqrt{\beta_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}\, e^{-\frac{1}{2}\left[\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha_{t-1}}}\right) x_{t-1}^2-\left(\frac{2 x_t \sqrt{\alpha_t}}{\beta_t}+\frac{2 x_0 \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_{t-1}}}\right) x_{t-1}+C(x_t,x_0)\right]} = \frac{1}{\sqrt{2\pi}\sigma} \cdot e^{-\frac{1}{2}\left[\frac{1}{\sigma^2} x_{t-1}^2-\frac{2\mu}{\sigma^2} x_{t-1} + \frac{\mu^2}{\sigma^2}\right]} \end{equation}
Matching the corresponding coefficients in Equation 9 gives
\begin{equation} \begin{split} \frac{\sqrt{1-\bar{\alpha_{t}}}}{\sqrt{2\pi} \sqrt{\beta_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}&= \frac{1}{\sqrt{2\pi}\sigma}\iff \sigma= \frac{\sqrt{\beta_t}\cdot \sqrt{1-\bar{\alpha_{t-1}}}}{\sqrt{1-\bar{\alpha_{t}}}} =\sqrt{\frac{\beta_t (1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_{t}}}}\\ \frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha_{t-1}}} &= \frac{1}{\sigma^2} \iff \sigma=\sqrt{\frac{\beta_t (1-\bar{\alpha_{t-1}})}{\alpha_t (1-\bar{\alpha_{t-1}})+\beta_t}}=\sqrt{\frac{\beta_t (1-\bar{\alpha_{t-1}})}{1-\alpha_t \bar{\alpha_{t-1}}}}=\sqrt{\frac{\beta_t (1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_{t}}}} \end{split} \end{equation}
Equation 10 shows that the two ways of solving for $\sigma$ agree. Next compute $\mu$, again by matching the coefficients of the $x_{t-1}$ term in Equation 9:
\begin{equation} \begin{split} \frac{2\mu}{\sigma^2}&=\frac{2 x_t \sqrt{\alpha_t}}{\beta_t}+\frac{2 x_0 \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_{t-1}}}\\ &\Updownarrow \\ \mu \cdot \left( \frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha_{t-1}}}\right) &= \frac{\sqrt{\alpha_t}\, x_t}{\beta_t}+\frac{x_0 \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_{t-1}}} \\ &\Updownarrow \\ \mu \cdot \frac{\alpha_t (1-\bar{\alpha_{t-1}})+\beta_t}{\beta_t (1-\bar{\alpha_{t-1}})}&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t-1}})\, x_t+\beta_t \sqrt{\bar{\alpha_{t-1}}}\, x_0 }{\beta_t (1-\bar{\alpha_{t-1}})} \\ &\Updownarrow \\ \mu &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t-1}})\, x_t+\beta_t \sqrt{\bar{\alpha_{t-1}}}\, x_0 }{\alpha_t-\alpha_t \bar{\alpha_{t-1}}+\beta_t} = \frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t-1}})\, x_t+\beta_t \sqrt{\bar{\alpha_{t-1}}}\, x_0 }{1-\bar{\alpha_t}} \\ &\Updownarrow \\ \mu&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_t}}\, x_t+\frac{\beta_t \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_t}}\, x_0 \end{split} \end{equation}
Equations 10 and 11 show that $x_{t-1}$, conditioned on $x_t$, is Gaussian: $x_{t-1} \sim N(\mu, \sigma^2)$, where $\sigma$ depends only on the schedule $\alpha$, while $\mu$ depends on both $x_0$ and $x_t$. But $x_0$ and $x_t$ are related through Equation 2, which can be rearranged to give
\begin{equation} x_0=\frac{1}{\sqrt{\bar{\alpha_t}}}\cdot\left(x_t-\sqrt{1-\bar{\alpha_t}}\cdot z_t\right) \end{equation}
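A quick numerical check that Equation 12 really inverts Equation 2 (the values are arbitrary and purely illustrative):

import torch

betas = torch.linspace(1e-4, 0.02, 200)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

t = 150
x0 = torch.randn(1, 3, 32, 32)
z = torch.randn_like(x0)
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * z          # Equation 2
x0_rec = (x_t - (1 - alphas_cumprod[t]).sqrt() * z) / alphas_cumprod[t].sqrt()    # Equation 12
print(torch.allclose(x0_rec, x0, atol=1e-5))   # True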
Substituting Equation 12 into Equation 11 gives
\begin{equation} \begin{split} \mu&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_t}}\, x_t+\frac{\beta_t \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_t}} \left[\frac{1}{\sqrt{\bar{\alpha_t}}}\left(x_t-\sqrt{1-\bar{\alpha_t}}\, z_t\right)\right]\\ &= \frac{1}{\sqrt{\alpha_t}}\cdot\sqrt{\alpha_t}\left\{\frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_t}}\, x_t+\frac{\beta_t \sqrt{\bar{\alpha_{t-1}}}}{1-\bar{\alpha_t}} \left[\frac{1}{\sqrt{\bar{\alpha_t}}}\left(x_t-\sqrt{1-\bar{\alpha_t}}\, z_t\right)\right]\right\}\\ &= \frac{1}{\sqrt{\alpha_t}}\left\{\frac{\alpha_t (1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_{t}}}\, x_t +\frac{\beta_t \sqrt{\bar{\alpha_t}}}{1-\bar{\alpha_t}} \left[\frac{1}{\sqrt{\bar{\alpha_t}}}\left(x_t-\sqrt{1-\bar{\alpha_t}}\, z_t\right)\right]\right\} \\ &= \frac{1}{\sqrt{\alpha_t}}\left[\frac{\alpha_t (1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_{t}}}\, x_t +\frac{\beta_t}{1-\bar{\alpha_t}}\left(x_t-\sqrt{1-\bar{\alpha_t}}\, z_t\right)\right]\\ &=\frac{1}{\sqrt{\alpha_t}}\left[\frac{\alpha_t (1-\bar{\alpha_{t-1}})+\beta_t}{1-\bar{\alpha_t}}\, x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\, z_t\right]\\ &=\frac{1}{\sqrt{\alpha_t}}\left[\frac{\alpha_t -\alpha_t \bar{\alpha_{t-1}}+\beta_t}{1-\bar{\alpha_t}}\, x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\, z_t\right]\\ &=\frac{1}{\sqrt{\alpha_t}}\left[x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\, z_t\right] \end{split} \end{equation}
Equations 10 and 13 thus give the variance and mean used to compute $x_{t-1}$ from $x_t$:
\begin{equation} \begin{split} \sigma&=\sqrt{\frac{\beta_t (1-\bar{\alpha_{t-1}})}{1-\bar{\alpha_{t}}}}\\ \mu&=\frac{1}{\sqrt{\alpha_t}}\cdot \left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\cdot z_t\right) \end{split} \end{equation}
The $z_t$ here is the noise that was introduced when going from $x_0$ to $x_t$; at sampling time it cannot be computed directly. The whole job of the diffusion model is therefore to train a network that predicts $z_t$ from $x_t$ and $t$.
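In other words, at sampling time the network's output replaces $z_t$ in Equation 14; this is what p_sample does in the full code below. A standalone sketch, where model is assumed to be a trained noise-prediction network taking (x, t):

import torch

betas = torch.linspace(1e-4, 0.02, 200)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def reverse_step(model, x_t, t):
    # mu of Equation 14, with the predicted noise standing in for z_t
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long)
    z_pred = model(x_t, t_batch)
    mean = (x_t - betas[t] / (1 - alphas_cumprod[t]).sqrt() * z_pred) / alphas[t].sqrt()
    if t == 0:
        return mean                              # last step: output the mean directly
    # sigma^2 of Equation 14 (the posterior variance from Equation 10)
    var = betas[t] * (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t])
    return mean + var.sqrt() * torch.randn_like(x_t)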
Code:
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.optim import Adam
from pathlib import Path
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from datasets import Dataset, load_dataset, load_from_disk
import warnings
warnings.filterwarnings("ignore")


# ---------- small helpers ----------

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    """Encodes the timestep t as a sinusoidal embedding, as in Transformers."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SiLU(nn.Module):
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)   # SiLU / swish activation: x * sigmoid(x)


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=4):
        super().__init__()
        self.mlp = (
            nn.Sequential(SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, 'b c -> b c 1 1') + h
        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, 'b c -> b c 1 1')
        h = self.net(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Unet(nn.Module):
    """U-Net that takes the noisy image x_t and the timestep t and predicts z_t."""

    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8),
                 channels=3, with_time_emb=True, resnet_block_groups=4,
                 use_convnext=False, convnext_mult=2):
        super().__init__()
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity(),
            ]))
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(nn.ModuleList([
                block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity(),
            ]))
        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1))

    def forward(self, x, time):
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)   # skip connection from the down path
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.final_conv(x)


# ---------- beta schedules ----------

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start


# ---------- precomputed constants for the forward/reverse process ----------

timesteps = 200
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1 - betas                                       # alpha_t = 1 - beta_t
alphas_cumprod = torch.cumprod(alphas, axis=0)           # alpha_bar_t
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  # alpha_bar_{t-1}
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# sigma^2 from Equation 10 / 14
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def extract(a, t, x_shape):
    """Gather the schedule value for each timestep in the batch, reshaped for broadcasting."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def q_sample(x_start, t, noise=None):
    """Forward process (Equation 2): jump from x_0 straight to x_t."""
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


# maps a tensor in [-1, 1] back to a PIL image; needed by get_noisy_image
reverse_transform = Compose([
    T.Lambda(lambda t: (t + 1) / 2),
    T.ToPILImage(),
])


def get_noisy_image(x_start, t):
    x_noisy = q_sample(x_start, t=t)
    noisy_image = reverse_transform(x_noisy.squeeze())
    return noisy_image


def p_losses(denoise_model, x_start, t, noise=None, loss_type='l1'):
    """Noise x_0 to x_t, let the model predict the noise, and compare with the truth."""
    if noise is None:
        noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == 'huber':
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()
    return loss


def transforms(examples):
    # examples['image'][0].show()
    transform = Compose([
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Lambda(lambda t: (t * 2) - 1),   # scale pixels from [0, 1] to [-1, 1]
    ])
    examples['pixel_values'] = [transform(image.convert('L')) for image in examples['image']]
    del examples['image']
    return examples


@torch.no_grad()
def p_sample(model, x, t, t_index):
    """One reverse step (Equation 14): compute the mean from the predicted noise,
    then add posterior noise except at the final step."""
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


def p_sample_loop(model, shape):
    """Start from pure noise and denoise step by step from t = T - 1 down to 0."""
    device = next(model.parameters()).device
    b = shape[0]
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


@torch.no_grad()
def sample(model, image_size, batch_size, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def train():
    image_size = 28
    channels = 1
    batch_size = 128
    # dataset = load_dataset('fashion_mnist')
    # dataset.save_to_disk('/Users/mac/Desktop/DALLE2-pytorch/data')
    dataset = load_from_disk('/Users/mac/Desktop/DALLE2-pytorch/data')
    transformed_dataset = dataset.with_transform(transforms).remove_columns('label')
    dataload = DataLoader(transformed_dataset['train'], batch_size=batch_size, shuffle=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Unet(dim=image_size, channels=channels, dim_mults=(1, 2, 4))
    model.to(device)
    optimizer = Adam(model.parameters(), lr=1e-3)
    result_folder = Path('/results')
    result_folder.mkdir(parents=True, exist_ok=True)
    save_and_sample_every = 1000
    epochs = 1
    for epoch in range(epochs):
        for step, batch in enumerate(dataload):
            optimizer.zero_grad()
            batch_size = batch['pixel_values'].shape[0]
            batch = batch['pixel_values'].to(device)
            # pick a random timestep for every image in the batch
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            loss = p_losses(model, batch, t, loss_type='huber')
            if step % 1 == 0:
                print(f'loss : {loss.item()}')
            loss.backward()
            optimizer.step()
            if step != 0 and step % save_and_sample_every == 0:
                milestone = step // save_and_sample_every
                batches = num_to_groups(4, batch_size)
                # sample() returns one array per timestep; keep only the final denoised images
                all_images_list = list(map(
                    lambda n: torch.from_numpy(
                        sample(model, image_size=image_size, batch_size=n, channels=channels)[-1]),
                    batches))
                all_images = torch.cat(all_images_list, dim=0)
                all_images = (all_images + 1) * 0.5   # rescale from [-1, 1] to [0, 1]
                save_image(all_images, str(result_folder / f'sample-{milestone}.png'), nrow=6)
    samples = sample(model, image_size=image_size, batch_size=64, channels=channels)


if __name__ == '__main__':
    train()