这是用户在 2024-6-21 22:21 为 https://zhuanlan.zhihu.com/p/704536437 保存的双语快照页面,由 沉浸式翻译 提供双语支持。了解如何保存?
CleanDiffuser(二)一起来写一个Cross-solver的Diffusion Model!

CleanDiffuser(二)一起来写一个Cross-solver的Diffusion Model!

17 人赞同了该文章
发布于 2024-06-20 19:26・IP 属地天津 ,编辑于 2024-06-20 19:37・IP 属地天津

1. 什么是Cross-solver的Diffusion Model?

在一些同学的概念里DDPM/DDIM/DPM-Solvers等等都是“独立”的工作,它们之间并没有直接的联系。实际上这些模型都能统一到一个框架,让我只需要训练一个模型,就可以在推理时随意切换不同的solver,以及不同的sampling steps。这样的代码实现能让我们的分析十分方便,也能让我们的模型更加灵活。CleanDiffuser中的模型都支持这个功能,在这篇文章里我们一起来动手实现一下叭!

2. 理论基础:我们该如何统一

设想我们想用 Diffusion Model 生成目标分布 q0(x)q_0(x) 的样本。DiffusionSDE 通过 noise schedule {αt,σt}t[0,T]\{\alpha_t, \sigma_t\}_{t\in[0,T]} 定义了前向过程 {xt}t[0,T]\{x_t\}_{t\in[0,T]},任意时刻 t[0,T]\forall t\in[0,T]xtx_t 都服从:

(1)xt=αtx0+σtϵ, ϵN(0,I)x_t=\alpha_t x_0 + \sigma_t \epsilon, ~\epsilon\in\mathcal N(0,I) \tag{1}

这个过程可以表达成等价的随机微分方程 SDE,其中 wtw_t 是标准维纳过程:

(2)dxt=f(t)xtdt+g(t)dwt, x0q0(x0){\rm d}x_t = f(t) x_t {\rm d}t + g(t){\rm d}w_t, ~x_0\sim q_0(x_0) \tag{2}

怎么理解这个等价?x0x_0 开始求解(2)到 tt 时刻得到的 xtx_t,和直接用(1)对 x0x_0 加噪得到的 xtx_t 分布是相同的。因为他们是等价的,所以显然 f,gf,gα,σ\alpha,\sigma 之间一一对应。

(3)f(t)=dlogαtdt, g2(t)=dσt2dt2σt2dlogαtdtf(t)=\frac{{\rm d}\log\alpha_t}{{\rm d}t},~g^2(t)=\frac{{\rm d}\sigma_t^2}{{\rm d}t}-2\sigma_t^2\frac{{\rm d}\log\alpha_t}{{\rm d}t} \tag{3}

这里面的 f,g,α,σf,g,\alpha,\sigma 都是我们自己定义的已知量。既然两者是等价的,我们为什么要关注SDE形式呢?因为我们关心不是前向过程,而是逆向过程,我们希望从 qT(xT)q_T(x_T) (近似高斯噪声) 逆向生成 q0(x0)q_0(x_0) 的样本。DiffusionSDE 具有 SDE/ODE 两种逆向过程形式,它们具有相同的边际分布:

(4)dxt=[f(t)xtg2(t)xlogqt(xt)]dt+g(t)dw~t{\rm d}x_t=\left[f(t)x_t-g^2(t)\nabla_x\log q_t(x_t)\right]{\rm d}t+g(t){\rm d}\tilde w_t \tag{4}

(5)dxt=[f(t)xt12g2(t)xlogqt(xt)]dt{\rm d}x_t=\left[f(t)x_t-\frac{1}{2}g^2(t)\nabla_x\log q_t(x_t)\right]{\rm d}t \tag{5}

通过求解(4)或(5)我们就能从 qT(xT)q_T(x_T) 样本逆向生成 q0(x0)q_0(x_0) 样本。可惜这个微分方程里面存在一个未知项 xlogqt(xt)\nabla_x\log q_t(x_t)(称为 score function),所以秉承深度学习的精神,我们用神经网络来近似这个项,我们不妨先看看 xlogqt(xt|x0)\nabla_x\log q_t(x_t|x_0) 是什么样子的:

(6)xlogqt(xt|x0)=xlog[Cexp((xtαtx0)22σt2)]=1σt2(xtαtx0)=ϵtσt\nabla_x\log q_t(x_t|x_0)=\nabla_x\log\left[C\exp\left(-\frac{(x_t-\alpha_t x_0)^2}{2\sigma_t^2}\right)\right]=-\frac{1}{\sigma_t^2}(x_t-\alpha_t x_0)=-\frac{\epsilon_t}{\sigma_t} \tag{6}

其中 CC 是和 xx 无关的常量。通过最小化 Eqt(xt|x0)[d(sθ(xt,t),xlogqt(xt|x0))]\mathbb E_{q_t(x_t|x_0)}\left[d(s_\theta(x_t,t),\nabla_x\log q_t(x_t|x_0) )\right] 就能让神经网络 sθ(xt,t)s_\theta(x_t,t) 近似 xlogqt(xt)\nabla_x\log q_t(x_t)dd 是随便某种距离度量。在实践中一般不会用神经网络直接拟合 score function,而是用神经网络拟合单步添加的噪声,或直接预测“干净”数据:

(7)ϵθ(xt,t)σtxlogqt(xt)=ϵtL(θ)=Eqt(xt|x0)[d(ϵθ(xt,t),ϵt)]\epsilon_\theta(x_t, t) \approx -\sigma_t\nabla_x\log q_t(x_t)=\epsilon_t \\\mathcal L(\theta)=\mathbb E_{q_t(x_t|x_0)}\left[d(\epsilon_\theta(x_t,t),\epsilon_t)\right] \tag{7}

(8)xθ(xt,t)αtxtσt2xlogqt(xt)=αtxt+σtϵt=x0L(θ)=Eqt(xt|x0)[d(xθ(xt,t),x0)]x_\theta(x_t, t) \approx \alpha_t x_t-\sigma^2_t\nabla_x\log q_t(x_t)=\alpha_t x_t+\sigma_t\epsilon_t=x_0 \\\mathcal L(\theta)=\mathbb E_{q_t(x_t|x_0)}\left[d(x_\theta(x_t,t),x_0)\right] \tag{8}

关于神经网络训练的小结: 无论如何我们都是只是用神经网络估计逆向 SDE/ODE 中的未知项,在 DiffusionSDE 里面就是此处的 score function xlogqt(xt)\nabla_x\log q_t(x_t),不同的 Loss 设计也不过就是预测 score function 加个或乘个已知项。完成训练后我们就可以开始对 SDE/ODE 进行求解了,换句话说,目前为止我们还从未开始尝试求解,跟 Solver 都还没有关系! 看下面的三个例子,我们这样完成训练后就可以随便在 Solver 间切换了。

如果我们用的是 DiffusionSDE(VP-SDE),那么神经网络要拟合的未知项就是 scaled score function。将 reverse SDE 一阶离散用 Euler 求解就是 DDPM,将 reverse ODE 一阶离散用 Euler 求解就是 DDIM,将 reverse SDE/ODE 的线性部分求精确解,非线性部分取泰勒展开k阶估计就是k阶 DPM-Solver。训练一个神经网络我们就可以在这些 Solver 里面随意切换。

如果我们用的是 EDM,那么神经网络要拟合的未知项就是“干净”的数据(不管 preconditioning 的话)。将 reverse ODE 用 Euler 求解就是 EDM原文的 Euler,用 2nd order Heun 求解,就是 EDM 默认的 Heun。训练一个神经网络我们就可以在这些 Solver 里面随意切换。
如果我们用的是 Rectified flow,那么神经网络要拟合的未知项就是 Drift force。将 reverse ODE 用 Euler 求解就是 Rectified flow 的标准复现,我们当然也可以用 RK45/Heun 等等求解器。训练一个神经网络我们就可以在这些 Solver 里面随意切换。

当然,在这篇文章中我们主要关注的是 DiffusionSDE,也就是前文一直介绍的这种 formulation,所以我们假设已经训练好了一个神经网络 ϵθ(xt,t)σtxlogqt(xt)\epsilon_\theta(x_t,t)\approx-\sigma_t\nabla_x\log q_t(x_t),我们来看看在采样的时候是怎么变成 DDPM/DDIM/DPM-Solver的。

3. 在各种 Solvers 间穿梭自如

3.1 DDPM & DDIM

我们先将 reverse ODE 做一阶离散化,假设从时刻 τ=s\tau=s 开始,我们希望求解 τ=t\tau=t 时刻的 xtx_t

(9)xtxs=[f(s)xs+g2(s)2σsϵθ(xs,s)]Δt(10)xtxs=[(αtαs1)xs+αt(σtαtσsαs)ϵθ(xs,s)](11)xt=αtαsxs+αt(σtαtσsαs)ϵθ(xs,s)\begin{align}x_t-x_s&=\left[f(s)x_s+\frac{g^2(s)}{2\sigma_s}\epsilon_\theta(x_s,s)\right]\Delta t \tag{9} \\x_t-x_s&=\left[\left(\frac{\alpha_t}{\alpha_s}-1\right)\cdot x_s+\alpha_t\left(\frac{\sigma_t}{\alpha_t}-\frac{\sigma_s}{\alpha_s}\right)\cdot\epsilon_\theta(x_s,s)\right] \tag{10} \\x_t &= \frac{\alpha_t}{\alpha_s}\cdot x_s+\alpha_t\left(\frac{\sigma_t}{\alpha_t}-\frac{\sigma_s}{\alpha_s}\right)\cdot\epsilon_\theta(x_s,s) \tag{11}\end{align}

我们再看看 DDIM 的递推解噪过程,注意因为 DDIM 论文里和本文的符号系统有冲突,DDIM 中的 α\alpha 其实等价于本文的 α2\alpha^2,所以我们统一用 aa 来表示 DDIM 中的 α\alpha

(12)xt=at(xs1asϵθ(xs)as)+1atϵθ(xt)(13)xt=αt(xsσsϵθ(xs)αs)+σtϵθ(xt)(14)xt=αtαsxs+αt(σtαtσsαs)ϵθ(xs,s)\begin{align}x_t&=\sqrt{a_t}\left(\frac{x_s-\sqrt{1-a_s}\cdot\epsilon_\theta(x_s)}{\sqrt{a_s}}\right)+\sqrt{1-a_t}\cdot\epsilon_\theta(x_t) \tag{12} \\x_t&=\alpha_t\left(\frac{x_s-\sigma_s\cdot\epsilon_\theta(x_s)}{\alpha_s}\right)+\sigma_t\cdot\epsilon_\theta(x_t) \tag{13} \\x_t &= \frac{\alpha_t}{\alpha_s}\cdot x_s+\alpha_t\left(\frac{\sigma_t}{\alpha_t}-\frac{\sigma_s}{\alpha_s}\right)\cdot\epsilon_\theta(x_s,s) \tag{14} \\\end{align}

可以看到是完全一样的,DDIM 就是在用 Δα/Δt\Delta\alpha/\Delta t 估计 dα/dt{\rm d}\alpha/{\rm d}t,然后等式两边都乘上 Δt\Delta t 就变成递推式了。所有有兴趣的同学也可以尝试一下直接计算出 dα/dt{\rm d}\alpha/{\rm d}t 再显式乘上 Δt\Delta t,EDM 就属于这种做法。
同理,要得到 DDPM,只需要在 DDIM 的基础上补一个:

(15)cs=1at1as1asat=σtσs1αs2αt2c_s=\sqrt{\frac{1-a_{t}}{1-a_s}}\cdot\sqrt{1-\frac{a_s}{a_t}}=\frac{\sigma_t}{\sigma_s}\sqrt{1-\frac{\alpha^2_s}{\alpha^2_t}} \tag{15}

使用:

(16)xt=αt(xsσsϵθ(xs)αs)+σt2cs2ϵθ(xt)+csϵx_t=\alpha_t\left(\frac{x_s-\sigma_s\cdot\epsilon_\theta(x_s)}{\alpha_s}\right)+\sqrt{\sigma^2_t-c^2_s}\cdot\epsilon_\theta(x_t)+c_s\epsilon \tag{16}

就能得到 DDPM 的递推式了。这里的 ϵ\epsilon 是标准高斯噪声。

3.2 DPM-Solvers

(如果你对这些内容感到陌生且不懂,不如来康康之前的 DPM-Solver 的文章吧 80岁的公园大爷问我 DPMsolver 是什么然后我给他写了这个他说哦我看懂了 - 知乎 (zhihu.com)

DPM-Solver 注意到了 reverse ODE 的半线性,所以直接计算出了从 τ=s\tau=sτ=t\tau=t 的精确解,显然,其中的线性部分是可以直接计算的,而非线性部分能写出来,但算不了:

(17)xt=αtαsxsαtstdλdτστατϵθ(xτ,τ)dτx_t = \frac{\alpha_t}{\alpha_s}x_s-\alpha_t\int_s^t\frac{{\rm d}\lambda}{{\rm d}\tau}\frac{\sigma_\tau}{\alpha_\tau}\epsilon_\theta(x_\tau,\tau){\rm d}\tau \tag{17}

其中 λ=log(α/σ)\lambda=\log(\alpha/\sigma) 是 log信噪比。因此非线性部分可以用泰勒展开来估计,这就是 DPM-Solver 的核心思想。我们对非线性项进行展开,得到:

(18)xt=αtαsxsαtn=0k1ϵθ(n)(xs,s)λsλteλ(λλs)nn!dλ+O((λtλs)k+1)x_t = \frac{\alpha_t}{\alpha_s}x_s-\alpha_t\sum_{n=0}^{k-1}\epsilon_\theta^{(n)}(x_s,s)\int_{\lambda_s}^{\lambda_t}e^{-\lambda}\frac{(\lambda-\lambda_s)^n}{n!}{\rm d}\lambda+\mathcal O((\lambda_t-\lambda_s)^{k+1}) \tag{18}

这里虽然是一个一般形式,但是实际用的时候不会用很高阶的,因为会有不稳定的问题。我们直接来看看一阶和二阶的递推式,先是 ODE 的一阶形式,ht=λtλsh_t=\lambda_t-\lambda_s

ODE-DPM-Solver-1

(19)xt=αtαsxsσt(eht1)ϵθ(xs,s)x_t=\frac{\alpha_t}{\alpha_s}x_s-\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s) \tag{19}

λ\lambda 带进去替换掉会发现和(14)是完全一样的,所以 DPM-Solver 的一阶递推式就是 DDIM 的递推式。二阶需要计算神经网络的一阶导,这个事情比较困难,但是 DPM-Solver 证明了可以用两次的直接计算结果来估计这个一阶导的结果,这里就不展开,直接给出二阶递推式,假设 t<r<st < r < s 且记 r1=λrλshtr_1=\frac{\lambda_r-\lambda_s}{h_t}

ODE-DPM-Solver-2M

(20)xt=αtαsxsσt(eht1)ϵθ(xs,s)σt(eht1)2r1(ϵθ(xr,r)ϵθ(xs,s))x_t=\frac{\alpha_t}{\alpha_s}x_s-\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s)-\frac{\sigma_t(e^{h_t}-1)}{2r_1}\left(\epsilon_\theta(x_r,r)-\epsilon_\theta(x_s,s)\right) \tag{20}

那么很显然,对于 SDE,我们完全可以如法炮制地玩一遍,同样得到一阶和二阶的递推式:

SDE-DPM-Solver-1

(19)xt=αtαsxs2σt(eht1)ϵθ(xs,s)+σte2ht1ϵx_t=\frac{\alpha_t}{\alpha_s}x_s-2\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s)+\sigma_t\sqrt{e^{2h_t}-1}\cdot\epsilon \tag{19}

SDE-DPM-Solver-2M

(19)xt=αtαsxs2σt(eht1)ϵθ(xs,s)σt(eht1)r1(ϵθ(xr,r)ϵθ(xs,s))+σte2ht1ϵx_t=\frac{\alpha_t}{\alpha_s}x_s-2\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s)-\frac{\sigma_t(e^{h_t}-1)}{r_1}\left(\epsilon_\theta(x_r,r)-\epsilon_\theta(x_s,s)\right)\\+\sigma_t\sqrt{e^{2h_t}-1}\cdot\epsilon \tag{19}

在 DPM-Solver++ 中作者们提出用神经网络预测干净数据 x0x_0 而不是单步添加的噪声,以此提高高阶 Solver 在条件生成时的稳定性。具体地说,神经网络现在变成了 xθ(xt,t)x0=xtσtϵαtx_\theta(x_t, t)\approx x_0 = \frac{x_t - \sigma_t\epsilon}{\alpha_t}

ODE-DPM-Solver++1

(21)xt=σtσsxsαt(eht1)xθ(xs,s)x_t=\frac{\sigma_t}{\sigma_s}x_s-\alpha_t(e^{-h_t}-1)x_\theta(x_s, s) \tag{21}

ODE-DPM-Solver++2M

(22)xt=σtσsxsαt(eht1)xθ(xs,s)αt(eht1)2r1(xθ(xr,r)xθ(xs,s))x_t=\frac{\sigma_t}{\sigma_s}x_s-\alpha_t(e^{-h_t}-1)x_\theta(x_s, s)-\frac{\alpha_t(e^{-h_t}-1)}{2r_1}\left(x_\theta(x_r, r)-x_\theta(x_s, s)\right) \tag{22}

SDE-DPM-Solver++1

(23)xt=σtσsehtxs+αt(1e2ht)xθ(xs,s)+σt1e2htϵx_t=\frac{\sigma_t}{\sigma_s}e^{-h_t}x_s+\alpha_t(1-e^{-2h_t})x_\theta(x_s, s)+\sigma_t\sqrt{1-e^{-2h_t}}\cdot\epsilon \tag{23}

SDE-DPM-Solver++2M

(24)xt=σtσsehtxs+αt(1e2ht)xθ(xs,s)αt(1e2ht)r1(xθ(xr,r)xθ(xs,s))+σt1e2htϵx_t=\frac{\sigma_t}{\sigma_s}e^{-h_t}x_s+\alpha_t(1-e^{-2h_t})x_\theta(x_s, s)-\frac{\alpha_t(1-e^{-2h_t})}{r_1}\left(x_\theta(x_r, r)-x_\theta(x_s, s)\right) \\+\sigma_t\sqrt{1-e^{-2h_t}}\cdot\epsilon \tag{24}

啊!我们又多了8种可以选择的 Solvers!

4. 现在让我们看看怎么样代码实现吧!

4.1 训练

那么根据我们之前的理论,训练就是要用一个神经网络来拟合 reverse SDE/ODE 的未知项,在本文的例子里面,就是一个 scaled score function σtxlogqt(xt)-\sigma_t\nabla_x\log q_t(x_t)。至于用什么样的神经网络,我们并不关心,它取决于数据本身的结构特性。优化的 loss 是公式(7),伪代码如下

"""
不妨假设我们的神经网络就继承自 nn.Module 吧 ~
Inputs:
    - xt:  torch.Tensor, shape=(batch_size, *x_shape)
    - t:   torch.Tensor, shape=(batch_size, )
Outputs:
    - eps: torch.Tensor, shape=(batch_size, *x_shape)
"""
nn_diffusion = MyNetwork(**kwargs)
optimizer = torch.optim.Adam(nn_diffusion.parameters(), lr=1e-4)

"""
接下来我们需要完成 noise schedule 的设计
也就是原理部分的 alpha_t, sigma_t
通常有 linear, cosine 的 schedule 方式
不论哪种,我们都假设已经设计完了,它们的形状是 (T, )
T 是 diffusion process 离散化的时间步数
是的,这里的例子是一个 Discrete 版本的模型,时间区间 [0,1] 被我们离散化成了 T 个时间步
"""
alphas = torch.tensor(...) # (T, )
sigmas = torch.tensor(...) # (T, )

# 这是一个 batch,从数据集取出来的,但这里我们用随机生成的数据代替
x0 = torch.randn(batch_size, *x_shape)

# 开始更新噜
idx = torch.randint(0, T, (batch_size, ))
alpha, sigma = alphas[idx], sigmas[idx]

eps = torch.randn_like(x0)
xt = alpha * x0 + sigma * eps # 见公式(1)

pred_eps = nn_diffusion(xt, idx) # 因为离散化了,所以idx和t是等价的

loss = ((pred_eps - eps) ** 2).mean() # 见公式(7)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f'Loss: {loss.item():.4f}')

OK,简单地把数据来源换成你的数据集,示例的 one-step 更新换成你的训练循环,就可以开始训练了!这里面 nn_diffusion 就是原理部分里面的 ϵθ\epsilon_\theta

4.2 采样

训练完了之后,我们就可以开始采样了!根据我们之前的理论,我们可以随意在 DDPM/DDIM/DPM-Solver 间切换,这里我们就给出一个简单的单步采样伪代码:

"""
假设我们想从时间步 s 采样到时间步 t,此刻是第 i 个 sampling step
"""
alpha_s, sigma_s = alphas[s], sigmas[s]
alpha_t, sigma_t = alphas[t], sigmas[t]
lambda_s, lambda_t = torch.log(alpha_s / sigma_s), torch.log(alpha_t / sigma_t)
h_t = lambda_t - lambda_s
c_s = sigma_t / sigma_s * torch.sqrt(1 - alpha_s ** 2 / alpha_t ** 2)

eps_theta = nn_diffusion(x_s, s) # \epsilon_\theta(x_s, s)
eps = torch.randn_like(x_s)

# 使用高阶 DPM-Solver 需要保存一些历史
buffer = []

if solver == "ddpm":

    x_t = alpha_t * (x_s - sigma_s * eps_theta) / alpha_s + torch.sqrt(sigma_t ** 2 - c_s ** 2) * eps_theta
    # DDPM 在最后一步不加噪声哦
    if i < sampling_steps - 1:
        x_t += c_s * eps

elif solver == "ddim":

    x_t = alpha_t * (x_s - sigma_s * eps_theta) / alpha_s + sigma_t * eps_theta

elif solver == "ode_dpmsolver_1":

    # 小技巧,用 torch.expm1 代替 torch.exp() - 1 可以缓解数值不稳定
    x_t = alpha_t / alpha_s * x_s - sigma_t * (torch.expm1(h_t)) * eps_theta

elif solver == "ode_dpmsolver_2M":

    buffer.append(eps_theta)

    if i > 0:
        r = h_s / h_t
        D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
        x_t = (alpha_t / alpha_s) * x_s - sigma_t * torch.expm1(h_t) * D
    else:
        x_t = (alpha_t / alpha_s) * x_s - sigma_t * torch.expm1(h_t) * eps_theta

elif solver == "sde_dpmsolver_1":

    x_t = alpha_t / alpha_s * x_s - 2 * sigma_t * (torch.expm1(h_t)) * eps_theta + sigma_t * torch.sqrt(torch.expm1(2 * h_t)) * eps

elif solver == "sde_dpmsolver_2M":

    buffer.append(eps_theta)

    if i > 0:
        r = h_s / h_t
        D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
        x_t = (alpha_t / alpha_s) * x_s - 2 * sigma_t * torch.expm1(h_t) * D + sigma_t * torch.sqrt(torch.expm1(2 * h_t)) * eps
    else:
        x_t = (alpha_t / alpha_s) * x_s - 2 * sigma_t * torch.expm1(h_t) * eps_theta + sigma_t * torch.sqrt(torch.expm1(2 * h_t)) * eps

elif solver == "ode_dpmsolver++_1":

    x_t = sigma_t / sigma_s * x_s - alpha_t * (torch.expm1(-h_t)) * x_theta(x_s, s)

elif solver == "ode_dpmsolver++_2M":
    
        buffer.append(x_theta(x_s, s))
    
        if i > 0:
            r = h_s / h_t
            D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
            x_t = (sigma_t / sigma_s) * x_s - alpha_t * torch.expm1(-h_t) * D
        else:
            x_t = (sigma_t / sigma_s) * x_s - alpha_t * torch.expm1(-h_t) * x_theta(x_s, s)
    
elif solver == "sde_dpmsolver++_1":

    x_t = (sigma_t / sigma_s) * torch.exp(-h_t) * x_s + alpha_t * (1 - torch.exp(-2 * h_t)) * x_theta(x_s, s) + sigma_t * torch.sqrt(1 - torch.exp(-2 * h_t)) * eps

elif solver == "sde_dpmsolver++_2M":

    buffer.append(x_theta(x_s, s))

    if i > 0:
        r = h_s / h_t
        D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
        x_t = (sigma_t / sigma_s) * torch.exp(-h_t) * x_s + alpha_t * (1 - torch.exp(-2 * h_t)) * D + sigma_t * torch.sqrt(1 - torch.exp(-2 * h_t)) * eps
    else:
        x_t = (sigma_t / sigma_s) * torch.exp(-h_t) * x_s + alpha_t * (1 - torch.exp(-2 * h_t)) * x_theta(x_s, s) + sigma_t * torch.sqrt(1 - torch.exp(-2 * h_t)) * eps

else:
    raise NotImplementedError(f"Solver {solver} not implemented!")

# 现在我们就有了 x_t 咯 ~

哇!太爽了,10种 Solver 随便切换!只要训练一次!爽!

5. 总结

在这篇文章中,我们回顾了如何统一 DDPM/DDIM/DPM-Solver 的理论基础,以及如何在代码实现中随意切换这些 Solver。这种设计不仅让我们的代码更加灵活,也让我们的分析更加方便。希望这篇文章能帮助到你。如果有任何问题,欢迎在评论区留言,我会尽力解答。如果你觉得自己动手实现这些太麻烦了,并且你刚好是想用 Diffusion Models 做决策,那不如试试试试 CleanDiffuser 吧!这些功能都内置了,并且更灵活,更宽泛,更强大!最后,感谢你的阅读!

本文使用 Zhihu On VSCode 创作并发布
发布于 2024-06-20 19:26・IP 属地天津 ,编辑于 2024-06-20 19:37・IP 属地天津
欢迎参与讨论

6 条评论
默认
最新
anArkitek

请问 公式 (3) 有证明吗?

4 小时前 · IP 属地美国
董子斌
作者
有。如果看论文的话可以看song yang的score based generative models那篇paper,博客可以看苏剑林老师的生成扩散模型漫谈的SDE部分。证明大概的思路就是把(1)写成(2)的递推形式,就有f,g和alpha sigma的关系了。
3 小时前 · IP 属地天津
zzzzz
我也想加入TJU,dzb大佬带带[可怜]
9 小时前 · IP 属地新加坡
董子斌
作者
呜呜呜🧎🏻‍♂️
9 小时前 · IP 属地天津
雅痞
我也开始接触diffusion了,大佬带带
昨天 19:28 · IP 属地北京
董子斌
作者
你需要我带??
昨天 19:35 · IP 属地天津
想来知乎工作?请发送邮件到 jobs@zhihu.com