I recently went through some material on diffusion models, mainly the Bilibili video «2024了,diffusion为什么一直没对手?迪哥详解diffusion扩散模型直观理解、数学原理、PyTorch实现 超越GANs的范式转变!». These notes record my own understanding.

1. What does a diffusion model do?

As shown in the figure below, $x_0$ is the original, noise-free image, and $x_1,\cdots,x_t,\cdots,x_n$ is a sequence of images obtained by repeatedly adding noise to $x_0$. $z_t$ denotes the noise, drawn from a standard normal distribution, $z_t \sim N(0,1)$. One noising step is given by
$$
x_t=\sqrt{\alpha_t}\cdot x_{t-1}+\sqrt{\beta_t}\cdot z_t,\qquad \beta_t=1-\alpha_t \tag{1}
$$

$\alpha_t$ and $\beta_t$ are the weights of the previous image and of the noise, respectively. In the paper, $\beta_t$ follows a linear schedule of $n$ values from 0.0001 to 0.02, so $\beta_t$ grows as $t$ grows; in other words, the noise takes an ever larger share. The reason is that at the beginning the image is still clear and even a little noise is clearly visible, whereas later the image is already blurry and noticeably changing it requires more noise. As for why the formula uses $\sqrt{\alpha_t}$ and $\sqrt{\beta_t}$, my personal feeling is that this is purely for algebraic convenience: with $\alpha_t+\beta_t=1$, the square roots keep the variance of $x_t$ from blowing up as noise is added.

*(Figure: the forward noising chain from $x_0$ through $x_t$ to $x_n$.)*
When two independent, identically distributed Gaussian variables are added, the sum is again Gaussian: $N(\mu_1,\sigma_1^2)+N(\mu_2,\sigma_2^2)\sim N(\mu_1+\mu_2,\ \sigma_1^2+\sigma_2^2)$.
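This is easy to check numerically; a minimal sketch (the distribution parameters and sample size are arbitrary):

```python
import torch

# Empirical check: N(1, 2^2) + N(3, 4^2) should behave like N(4, 2^2 + 4^2) = N(4, 20).
a = torch.normal(1.0, 2.0, size=(1_000_000,))
b = torch.normal(3.0, 4.0, size=(1_000_000,))
s = a + b
print(s.mean().item())  # ~ 4.0  = mu1 + mu2
print(s.var().item())   # ~ 20.0 = sigma1^2 + sigma2^2
```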

In the forward process, $x_t$ is computed from $x_{t-1}$, which in turn is computed from $x_{t-2}$. Can $x_t$ be obtained directly from $x_0$? The derivation follows.

$$
\begin{aligned}
x_t &=\sqrt{\alpha_t}\cdot x_{t-1}+\sqrt{1-\alpha_t}\cdot z_t \\
&= \sqrt{\alpha_t}\cdot\left(\sqrt{\alpha_{t-1}}\cdot x_{t-2}+\sqrt{1-\alpha_{t-1}}\cdot z_{t-1}\right)+\sqrt{1-\alpha_t}\cdot z_t \\
&= \sqrt{\alpha_t\alpha_{t-1}}\cdot x_{t-2}+\sqrt{\alpha_t}\cdot\sqrt{1-\alpha_{t-1}}\cdot z_{t-1}+\sqrt{1-\alpha_t}\cdot z_t \\
&= \sqrt{\alpha_t\alpha_{t-1}}\cdot x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\cdot z \\
&\;\;\vdots \\
&= \sqrt{\alpha_t\alpha_{t-1}\cdots\alpha_1}\cdot x_0+\sqrt{1-\alpha_t\alpha_{t-1}\cdots\alpha_1}\cdot z \\
&= \sqrt{\bar{\alpha}_t}\cdot x_0+\sqrt{1-\bar{\alpha}_t}\cdot z \qquad\text{(with } \bar{\alpha}_t=\alpha_t\alpha_{t-1}\cdots\alpha_1\text{)}
\end{aligned} \tag{2}
$$

Here $z_t$ and $z_{t-1}$ are the noise samples added at steps $t$ and $t-1$; they are independent and identically distributed, so the last two terms in the third line can be merged, giving the fourth line. The merged variable is still Gaussian, only with a different mean and variance, and is written $z$. The same merging is applied at every later step of the expansion.

So although $x_t$ was originally built up one noising step at a time, Eq. (2) shows that it can be computed in a single step from $x_0$, a normal noise sample $z$, and the noise weights $\alpha_t,\alpha_{t-1},\cdots,\alpha_1$. It also shows that $x_t\sim N(\sqrt{\bar{\alpha}_t}\cdot x_0,\ 1-\bar{\alpha}_t)$.
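As a sanity check of Eq. (2), here is a minimal sketch comparing step-by-step noising via Eq. (1) with one-shot noising via Eq. (2); the linear schedule below is an assumption that matches the code at the end of this post:

```python
import torch

torch.manual_seed(0)
T = 200
betas = torch.linspace(1e-4, 0.02, T)      # assumed linear beta schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # alpha_bar_t = alpha_1 * ... * alpha_t

x0 = torch.full((100_000,), 2.0)  # a toy "image": every pixel equals 2
t = 150                           # an arbitrary timestep

# Step-by-step noising, Eq. (1), applied t+1 times
x = x0.clone()
for s in range(t + 1):
    x = torch.sqrt(alphas[s]) * x + torch.sqrt(betas[s]) * torch.randn_like(x)

# One-shot noising, Eq. (2)
x_direct = torch.sqrt(alpha_bars[t]) * x0 + torch.sqrt(1 - alpha_bars[t]) * torch.randn_like(x0)

# Both should match N(sqrt(alpha_bar_t) * 2, 1 - alpha_bar_t).
print(x.mean().item(), x.var().item())
print(x_direct.mean().item(), x_direct.var().item())
```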

The forward noising process (going from $x_{t-1}$ to $x_t$, written $p(x_t|x_{t-1})$) is easy to understand, since we define it ourselves. We now turn to the reverse denoising process: obtaining $x_{t-1}$ from $x_t$, written $p(x_{t-1}|x_t)$.

By Bayes' rule,

$$
p(x_{t-1}|x_t,x_0) = p(x_t|x_{t-1},x_0)\cdot \frac{p(x_{t-1}|x_0)}{p(x_t|x_0)} \tag{3}
$$

First consider the factor $p(x_t|x_{t-1},x_0)$.

From Eq. (1), $x_t=\sqrt{\alpha_t}\cdot x_{t-1}+\sqrt{1-\alpha_t}\cdot z_t$, so conditioned on $x_{t-1}$ we have $x_t\sim N(\mu=\sqrt{\alpha_t}\cdot x_{t-1},\ \sigma^2=1-\alpha_t)$. Hence $p(x_t|x_{t-1},x_0)$ can be written as
$$
\begin{aligned}
p(x_t|x_{t-1},x_0) &= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(x_t-\mu)^2}{2\sigma^2}} &\text{(Gaussian density)} \\
&= \frac{1}{\sqrt{2\pi}\sqrt{1-\alpha_t}}\, e^{-\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{2(1-\alpha_t)}} &\text{(substituting $\mu$ and $\sigma$)}
\end{aligned} \tag{4}
$$

Next consider $p(x_{t-1}|x_0)$ in Eq. (3).

From $x_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\cdot x_0+\sqrt{1-\bar{\alpha}_{t-1}}\cdot z$ we have $x_{t-1}\sim N(\mu=\sqrt{\bar{\alpha}_{t-1}}\cdot x_0,\ \sigma^2=1-\bar{\alpha}_{t-1})$, so $p(x_{t-1}|x_0)$ can be written as
$$
\begin{aligned}
p(x_{t-1}|x_0) &= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(x_{t-1}-\mu)^2}{2\sigma^2}} &\text{(Gaussian density)} \\
&= \frac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha}_{t-1}}}\, e^{-\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{2(1-\bar{\alpha}_{t-1})}} &\text{(substituting $\mu$ and $\sigma$)}
\end{aligned} \tag{5}
$$

Finally consider $p(x_t|x_0)$ in Eq. (3).

From $x_t=\sqrt{\bar{\alpha}_t}\cdot x_0+\sqrt{1-\bar{\alpha}_t}\cdot z$ we have $x_t\sim N(\mu=\sqrt{\bar{\alpha}_t}\cdot x_0,\ \sigma^2=1-\bar{\alpha}_t)$, so $p(x_t|x_0)$ can be written as
$$
\begin{aligned}
p(x_t|x_0) &= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(x_t-\mu)^2}{2\sigma^2}} &\text{(Gaussian density)} \\
&= \frac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha}_t}}\, e^{-\frac{(x_t-\sqrt{\bar{\alpha}_t}\,x_0)^2}{2(1-\bar{\alpha}_t)}} &\text{(substituting $\mu$ and $\sigma$)}
\end{aligned} \tag{6}
$$

Substituting Eqs. (4), (5), and (6) into Eq. (3) gives
$$
\begin{aligned}
p(x_t|x_{t-1},x_0)\cdot \frac{p(x_{t-1}|x_0)}{p(x_t|x_0)}
&= \frac{1}{\sqrt{2\pi}\sqrt{1-\alpha_t}}\, e^{-\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{2(1-\alpha_t)}}\cdot
\frac{\dfrac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha}_{t-1}}}\, e^{-\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{2(1-\bar{\alpha}_{t-1})}}}{\dfrac{1}{\sqrt{2\pi}\sqrt{1-\bar{\alpha}_t}}\, e^{-\frac{(x_t-\sqrt{\bar{\alpha}_t}\,x_0)^2}{2(1-\bar{\alpha}_t)}}} \\
&= \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{2\pi}\sqrt{1-\alpha_t}\sqrt{1-\bar{\alpha}_{t-1}}}\,
e^{-\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{2(1-\alpha_t)}-\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{2(1-\bar{\alpha}_{t-1})}+\frac{(x_t-\sqrt{\bar{\alpha}_t}\,x_0)^2}{2(1-\bar{\alpha}_t)}} \\
&= \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{2\pi}\sqrt{\beta_t}\sqrt{1-\bar{\alpha}_{t-1}}}\,
e^{-\frac{1}{2}\left[\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_t-\sqrt{\bar{\alpha}_t}\,x_0)^2}{1-\bar{\alpha}_t}\right]} \\
&= \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{2\pi}\sqrt{\beta_t}\sqrt{1-\bar{\alpha}_{t-1}}}\,
e^{-\frac{1}{2}\left[\frac{x_t^2-2\sqrt{\alpha_t}\,x_t x_{t-1}+\alpha_t x_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2\sqrt{\bar{\alpha}_{t-1}}\,x_{t-1}x_0+\bar{\alpha}_{t-1}x_0^2}{1-\bar{\alpha}_{t-1}}-\frac{x_t^2-2\sqrt{\bar{\alpha}_t}\,x_t x_0+\bar{\alpha}_t x_0^2}{1-\bar{\alpha}_t}\right]} \\
&= \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{2\pi}\sqrt{\beta_t}\sqrt{1-\bar{\alpha}_{t-1}}}\,
e^{-\frac{1}{2}\left[\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}^2-\left(\frac{2\sqrt{\alpha_t}\,x_t}{\beta_t}+\frac{2\sqrt{\bar{\alpha}_{t-1}}\,x_0}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}+C(x_t,x_0)\right]}
\end{aligned} \tag{7}
$$
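The completing-the-square step in the last two lines is easy to get wrong; here is a minimal symbolic check of the $x_{t-1}$ coefficients (a sketch using sympy, assuming only $\beta_t=1-\alpha_t$ and $\bar{\alpha}_t=\alpha_t\bar{\alpha}_{t-1}$):

```python
import sympy as sp

x_t, x_tm1, x_0 = sp.symbols('x_t x_tm1 x_0', real=True)
alpha_t, abar_prev = sp.symbols('alpha_t abar_prev', positive=True)
beta_t = 1 - alpha_t            # Eq. (1)
abar_t = alpha_t * abar_prev    # definition of alpha_bar_t

# Exponent of Eq. (7), third line
expo = -sp.Rational(1, 2) * ((x_t - sp.sqrt(alpha_t) * x_tm1)**2 / beta_t
                             + (x_tm1 - sp.sqrt(abar_prev) * x_0)**2 / (1 - abar_prev)
                             - (x_t - sp.sqrt(abar_t) * x_0)**2 / (1 - abar_t))
e = sp.expand(expo)

# Coefficients claimed in the last line of Eq. (7)
expected2 = -(alpha_t / beta_t + 1 / (1 - abar_prev)) / 2
expected1 = sp.sqrt(alpha_t) * x_t / beta_t + sp.sqrt(abar_prev) * x_0 / (1 - abar_prev)
print(sp.simplify(e.coeff(x_tm1, 2) - expected2))  # prints 0
print(sp.simplify(e.coeff(x_tm1, 1) - expected1))  # prints 0
```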

Now consider the left-hand side of Eq. (3).

$$
\begin{aligned}
p(x_{t-1}|x_t,x_0) &= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(x_{t-1}-\mu)^2}{2\sigma^2}} &\text{(Gaussian density)} \\
&= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{1}{2}\left[\frac{x_{t-1}^2-2\mu x_{t-1}+\mu^2}{\sigma^2}\right]} \\
&= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{1}{2}\left[\frac{1}{\sigma^2}x_{t-1}^2-\frac{2\mu}{\sigma^2}x_{t-1}+\frac{\mu^2}{\sigma^2}\right]}
\end{aligned} \tag{8}
$$

Setting Eq. (7) equal to Eq. (8),

$$
\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{2\pi}\sqrt{\beta_t}\sqrt{1-\bar{\alpha}_{t-1}}}\,
e^{-\frac{1}{2}\left[\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}^2-\left(\frac{2\sqrt{\alpha_t}\,x_t}{\beta_t}+\frac{2\sqrt{\bar{\alpha}_{t-1}}\,x_0}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}+C(x_t,x_0)\right]}
= \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{1}{2}\left[\frac{1}{\sigma^2}x_{t-1}^2-\frac{2\mu}{\sigma^2}x_{t-1}+\frac{\mu^2}{\sigma^2}\right]} \tag{9}
$$
Matching the corresponding coefficients in Eq. (9) gives

$$
\begin{aligned}
\frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{2\pi}\sqrt{\beta_t}\sqrt{1-\bar{\alpha}_{t-1}}} &= \frac{1}{\sqrt{2\pi}\,\sigma}
\iff \sigma = \frac{\sqrt{\beta_t}\sqrt{1-\bar{\alpha}_{t-1}}}{\sqrt{1-\bar{\alpha}_t}}
= \sqrt{\frac{\beta_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}} \\
\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}} &= \frac{1}{\sigma^2}
\iff \sigma = \sqrt{\frac{\beta_t(1-\bar{\alpha}_{t-1})}{\alpha_t(1-\bar{\alpha}_{t-1})+\beta_t}}
= \sqrt{\frac{\beta_t(1-\bar{\alpha}_{t-1})}{1-\alpha_t\bar{\alpha}_{t-1}}}
= \sqrt{\frac{\beta_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}}
\end{aligned} \tag{10}
$$
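The equality of these two expressions for $\sigma$ can also be checked symbolically; a minimal sketch using sympy, assuming only $\beta_t=1-\alpha_t$ and $\bar{\alpha}_t=\alpha_t\bar{\alpha}_{t-1}$:

```python
import sympy as sp

alpha_t, abar_prev = sp.symbols('alpha_t abar_prev', positive=True)
beta_t = 1 - alpha_t            # Eq. (1)
abar_t = alpha_t * abar_prev    # definition of alpha_bar_t

# sigma^2 from matching the x_{t-1}^2 coefficient vs. the closed form
var_from_coeff = 1 / (alpha_t / beta_t + 1 / (1 - abar_prev))
var_closed_form = beta_t * (1 - abar_prev) / (1 - abar_t)

print(sp.simplify(var_from_coeff - var_closed_form))  # prints 0
```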
Eq. (10) shows that both identifications yield the same $\sigma$. Next we compute $\mu$. Matching coefficients in Eq. (9) again,
$$
\begin{aligned}
\frac{2\mu}{\sigma^2} &= \frac{2\sqrt{\alpha_t}\,x_t}{\beta_t}+\frac{2\sqrt{\bar{\alpha}_{t-1}}\,x_0}{1-\bar{\alpha}_{t-1}} \\
&\Updownarrow \\
\mu\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) &= \frac{\sqrt{\alpha_t}\,x_t}{\beta_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\,x_0}{1-\bar{\alpha}_{t-1}} \\
&\Updownarrow \\
\mu\cdot\frac{\alpha_t(1-\bar{\alpha}_{t-1})+\beta_t}{\beta_t(1-\bar{\alpha}_{t-1})} &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\,x_t+\beta_t\sqrt{\bar{\alpha}_{t-1}}\,x_0}{\beta_t(1-\bar{\alpha}_{t-1})} \\
&\Updownarrow \\
\mu &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\,x_t+\beta_t\sqrt{\bar{\alpha}_{t-1}}\,x_0}{\alpha_t-\alpha_t\bar{\alpha}_{t-1}+\beta_t} \\
&\Updownarrow \\
\mu &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\,x_t+\beta_t\sqrt{\bar{\alpha}_{t-1}}\,x_0}{1-\bar{\alpha}_t} \\
&\Updownarrow \\
\mu &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t+\frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t}\,x_0
\end{aligned} \tag{11}
$$
Eqs. (10) and (11) show that $x_{t-1}$ is Gaussian, $x_{t-1}\sim N(\mu,\sigma^2)$, where $\sigma$ depends only on the schedule weights $\alpha$, while $\mu$ depends on both $x_0$ and $x_t$. At the same time, $x_0$ and $x_t$ are related through Eq. (2), which can be rearranged as
$$
x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\cdot z_t\right) \tag{12}
$$
Substituting Eq. (12) into Eq. (11) yields
$$
\begin{aligned}
\mu &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t+\frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t}\left[\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\,z_t\right)\right] \\
&= \frac{1}{\sqrt{\alpha_t}}\cdot\sqrt{\alpha_t}\left\{\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t+\frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t}\left[\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\,z_t\right)\right]\right\} \\
&= \frac{1}{\sqrt{\alpha_t}}\left\{\frac{\alpha_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t+\frac{\beta_t\sqrt{\bar{\alpha}_t}}{1-\bar{\alpha}_t}\left[\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\,z_t\right)\right]\right\} \\
&= \frac{1}{\sqrt{\alpha_t}}\left[\frac{\alpha_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t+\frac{\beta_t}{1-\bar{\alpha}_t}\left(x_t-\sqrt{1-\bar{\alpha}_t}\,z_t\right)\right] \\
&= \frac{1}{\sqrt{\alpha_t}}\left[\frac{\alpha_t(1-\bar{\alpha}_{t-1})+\beta_t}{1-\bar{\alpha}_t}\,x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,z_t\right] \\
&= \frac{1}{\sqrt{\alpha_t}}\left[\frac{\alpha_t-\alpha_t\bar{\alpha}_{t-1}+\beta_t}{1-\bar{\alpha}_t}\,x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,z_t\right] \\
&= \frac{1}{\sqrt{\alpha_t}}\left[x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,z_t\right]
\end{aligned} \tag{13}
$$

Eqs. (10) and (13) therefore give the variance and mean needed to compute $x_{t-1}$ from $x_t$:
$$
\begin{aligned}
\sigma &= \sqrt{\frac{\beta_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}} \\
\mu &= \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\cdot z_t\right)
\end{aligned} \tag{14}
$$

The $z_t$ here is the noise that was introduced when computing $x_t$ from $x_0$; at sampling time it is unknown and cannot be computed directly. The whole job of a diffusion model is therefore to train a network that predicts $z_t$ from $x_t$ and $t$.
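Putting Eq. (14) to work, one reverse step looks roughly like the sketch below, where `eps_model(x_t, t)` stands for a hypothetical trained network that predicts $z_t$ (the full version is `p_sample` in the code that follows):

```python
import torch

@torch.no_grad()
def reverse_step(eps_model, x_t, t, betas, alphas, alpha_bars):
    """One reverse step x_t -> x_{t-1} using the mean and variance of Eq. (14)."""
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long)
    z_pred = eps_model(x_t, t_batch)  # network's estimate of z_t
    mean = (x_t - betas[t] / torch.sqrt(1 - alpha_bars[t]) * z_pred) / torch.sqrt(alphas[t])
    if t == 0:
        return mean  # final step: return the mean, add no fresh noise
    var = betas[t] * (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t])  # sigma^2 from Eq. (10)
    return mean + torch.sqrt(var) * torch.randn_like(x_t)
```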

Code:

```python
import math
from inspect import isfunction
from functools import partial
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.optim import Adam
from pathlib import Path
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from datasets import Dataset, load_dataset, load_from_disk
import warnings

warnings.filterwarnings("ignore")


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SiLU(nn.Module):
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)  # SiLU/Swish: x * sigmoid(x)


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=4):
        super().__init__()
        self.mlp = (nn.Sequential(SiLU(), nn.Linear(time_emb_dim, dim_out))
                    if exists(time_emb_dim) else None)
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, 'b c -> b c 1 1') + h
        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
                    if exists(time_emb_dim) else None)
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, 'b c -> b c 1 1')
        h = self.net(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Unet(nn.Module):
    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8),
                 channels=3, with_time_emb=True, resnet_block_groups=4,
                 use_convnext=False, convnext_mult=2):
        super().__init__()
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity(),
            ]))
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(nn.ModuleList([
                block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity(),
            ]))
        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1))

    def forward(self, x, time):
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.final_conv(x)


def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start


timesteps = 200
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Posterior variance of Eq. (10)/(14): beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def extract(a, t, x_shape):
    # Gather the per-sample schedule value a[t] and reshape it for broadcasting.
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def q_sample(x_start, t, noise=None):
    # Forward process in closed form, Eq. (2).
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


# reverse_transform was referenced but not defined in the original listing;
# this is an assumed minimal inverse of the training transform below.
reverse_transform = Compose([
    T.Lambda(lambda t: (t + 1) / 2),
    T.ToPILImage(),
])


def get_noisy_image(x_start, t):
    x_noisy = q_sample(x_start, t=t)
    noisy_image = reverse_transform(x_noisy.squeeze())
    return noisy_image


def p_losses(denoise_model, x_start, t, noise=None, loss_type='l1'):
    # Train the network to predict the noise added by q_sample.
    if noise is None:
        noise = torch.randn_like(x_start)
    x_noise = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noise, t)
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == 'huber':
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()
    return loss


def transforms(examples):
    # examples['image'][0].show()
    transform = Compose([
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Lambda(lambda t: (t * 2) - 1),  # scale to [-1, 1]
    ])
    examples['pixel_values'] = [transform(image.convert('L')) for image in examples['image']]
    del examples['image']
    return examples


@torch.no_grad()
def p_sample(model, x, t, t_index):
    # One reverse step, Eq. (14): mean from the predicted noise,
    # variance from posterior_variance (Eq. (10)).
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_t) * noise


def p_sample_loop(model, shape):
    # Start from pure noise and denoise step by step.
    device = next(model.parameters()).device
    b = shape[0]
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


@torch.no_grad()
def sample(model, image_size, batch_size, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def train():
    image_size = 28
    channels = 1
    batch_size = 128
    # dataset = load_dataset('fashion_mnist')
    # dataset.save_to_disk('/Users/mac/Desktop/DALLE2-pytorch/data')
    dataset = load_from_disk('/Users/mac/Desktop/DALLE2-pytorch/data')
    transformed_dataset = dataset.with_transform(transforms).remove_columns('label')
    dataload = DataLoader(transformed_dataset['train'], batch_size=batch_size, shuffle=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Unet(dim=image_size, channels=channels, dim_mults=(1, 2, 4))
    model.to(device)
    optimizer = Adam(model.parameters(), lr=1e-3)
    result_folder = Path('/results')
    result_folder.mkdir(parents=True, exist_ok=True)
    save_and_sample_every = 1000
    epochs = 1
    for epoch in range(epochs):
        for step, batch in enumerate(dataload):
            optimizer.zero_grad()
            batch_size = batch['pixel_values'].shape[0]
            batch = batch['pixel_values'].to(device)
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            loss = p_losses(model, batch, t, loss_type='huber')
            if step % 1 == 0:
                print(f'loss : {loss.item()}')
            loss.backward()
            optimizer.step()
            if step != 0 and step % save_and_sample_every == 0:
                milestone = step // save_and_sample_every
                batches = num_to_groups(4, batch_size)
                # sample() returns one array per timestep; keep only the final images.
                all_images_list = [
                    torch.from_numpy(sample(model, image_size=image_size, batch_size=n, channels=channels)[-1])
                    for n in batches
                ]
                all_images = torch.cat(all_images_list, dim=0)
                all_images = (all_images + 1) * 0.5  # rescale from [-1, 1] to [0, 1]
                save_image(all_images, str(result_folder / f'sample-{milestone}.png'), nrow=6)
    samples = sample(model, image_size=image_size, batch_size=64, channels=channels)


if __name__ == '__main__':
    train()
```
