CleanDiffuser(二)一起来写一个Cross-solver的Diffusion Model!
1. 什么是Cross-solver的Diffusion Model?
在一些同学的概念里DDPM/DDIM/DPM-Solvers等等都是“独立”的工作,它们之间并没有直接的联系。实际上这些模型都能统一到一个框架,让我只需要训练一个模型,就可以在推理时随意切换不同的solver,以及不同的sampling steps。这样的代码实现能让我们的分析十分方便,也能让我们的模型更加灵活。CleanDiffuser中的模型都支持这个功能,在这篇文章里我们一起来动手实现一下叭!
2. 理论基础:我们该如何统一
设想我们想用 Diffusion Model 生成目标分布 q_0(x) 的样本。DiffusionSDE 通过 noise schedule \{\alpha_t, \sigma_t\}_{t\in[0,T]} 定义了前向过程 \{x_t\}_{t\in[0,T]},任意时刻 \forall t\in[0,T] 的 x_t 都服从:
x_t=\alpha_t x_0 + \sigma_t \epsilon, ~\epsilon\in\mathcal N(0,I) \tag{1}
这个过程可以表达成等价的随机微分方程 SDE,其中 w_t 是标准维纳过程:
{\rm d}x_t = f(t) x_t {\rm d}t + g(t){\rm d}w_t, ~x_0\sim q_0(x_0) \tag{2}
怎么理解这个等价? 从 x_0 开始求解(2)到 t 时刻得到的 x_t,和直接用(1)对 x_0 加噪得到的 x_t 分布是相同的。因为他们是等价的,所以显然 f,g 和 \alpha,\sigma 之间一一对应。
f(t)=\frac{{\rm d}\log\alpha_t}{{\rm d}t},~g^2(t)=\frac{{\rm d}\sigma_t^2}{{\rm d}t}-2\sigma_t^2\frac{{\rm d}\log\alpha_t}{{\rm d}t} \tag{3}
这里面的 f,g,\alpha,\sigma 都是我们自己定义的已知量。既然两者是等价的,我们为什么要关注SDE形式呢?因为我们关心不是前向过程,而是逆向过程,我们希望从 q_T(x_T) (近似高斯噪声) 逆向生成 q_0(x_0) 的样本。DiffusionSDE 具有 SDE/ODE 两种逆向过程形式,它们具有相同的边际分布:
{\rm d}x_t=\left[f(t)x_t-g^2(t)\nabla_x\log q_t(x_t)\right]{\rm d}t+g(t){\rm d}\tilde w_t \tag{4}
{\rm d}x_t=\left[f(t)x_t-\frac{1}{2}g^2(t)\nabla_x\log q_t(x_t)\right]{\rm d}t \tag{5}
通过求解(4)或(5)我们就能从 q_T(x_T) 样本逆向生成 q_0(x_0) 样本。可惜这个微分方程里面存在一个未知项 \nabla_x\log q_t(x_t)(称为 score function),所以秉承深度学习的精神,我们用神经网络来近似这个项,我们不妨先看看 \nabla_x\log q_t(x_t|x_0) 是什么样子的:
\nabla_x\log q_t(x_t|x_0)=\nabla_x\log\left[C\exp\left(-\frac{(x_t-\alpha_t x_0)^2}{2\sigma_t^2}\right)\right]=-\frac{1}{\sigma_t^2}(x_t-\alpha_t x_0)=-\frac{\epsilon_t}{\sigma_t} \tag{6}
其中 C 是和 x 无关的常量。通过最小化 \mathbb E_{q_t(x_t|x_0)}\left[d(s_\theta(x_t,t),\nabla_x\log q_t(x_t|x_0) )\right] 就能让神经网络 s_\theta(x_t,t) 近似 \nabla_x\log q_t(x_t),d 是随便某种距离度量。在实践中一般不会用神经网络直接拟合 score function,而是用神经网络拟合单步添加的噪声,或直接预测“干净”数据:
\epsilon_\theta(x_t, t) \approx -\sigma_t\nabla_x\log q_t(x_t)=\epsilon_t \\\mathcal L(\theta)=\mathbb E_{q_t(x_t|x_0)}\left[d(\epsilon_\theta(x_t,t),\epsilon_t)\right] \tag{7}
x_\theta(x_t, t) \approx \alpha_t x_t-\sigma^2_t\nabla_x\log q_t(x_t)=\alpha_t x_t+\sigma_t\epsilon_t=x_0 \\\mathcal L(\theta)=\mathbb E_{q_t(x_t|x_0)}\left[d(x_\theta(x_t,t),x_0)\right] \tag{8}
关于神经网络训练的小结: 无论如何我们都是只是用神经网络估计逆向 SDE/ODE 中的未知项,在 DiffusionSDE 里面就是此处的 score function \nabla_x\log q_t(x_t),不同的 Loss 设计也不过就是预测 score function 加个或乘个已知项。完成训练后我们就可以开始对 SDE/ODE 进行求解了,换句话说,目前为止我们还从未开始尝试求解,跟 Solver 都还没有关系! 看下面的三个例子,我们这样完成训练后就可以随便在 Solver 间切换了。
如果我们用的是 DiffusionSDE(VP-SDE),那么神经网络要拟合的未知项就是 scaled score function。将 reverse SDE 一阶离散用 Euler 求解就是 DDPM,将 reverse ODE 一阶离散用 Euler 求解就是 DDIM,将 reverse SDE/ODE 的线性部分求精确解,非线性部分取泰勒展开k阶估计就是k阶 DPM-Solver。训练一个神经网络我们就可以在这些 Solver 里面随意切换。
如果我们用的是 EDM,那么神经网络要拟合的未知项就是“干净”的数据(不管 preconditioning 的话)。将 reverse ODE 用 Euler 求解就是 EDM原文的 Euler,用 2nd order Heun 求解,就是 EDM 默认的 Heun。训练一个神经网络我们就可以在这些 Solver 里面随意切换。
如果我们用的是 Rectified flow,那么神经网络要拟合的未知项就是 Drift force。将 reverse ODE 用 Euler 求解就是 Rectified flow 的标准复现,我们当然也可以用 RK45/Heun 等等求解器。训练一个神经网络我们就可以在这些 Solver 里面随意切换。
当然,在这篇文章中我们主要关注的是 DiffusionSDE,也就是前文一直介绍的这种 formulation,所以我们假设已经训练好了一个神经网络 \epsilon_\theta(x_t,t)\approx-\sigma_t\nabla_x\log q_t(x_t),我们来看看在采样的时候是怎么变成 DDPM/DDIM/DPM-Solver的。
3. 在各种 Solvers 间穿梭自如
3.1 DDPM & DDIM
我们先将 reverse ODE 做一阶离散化,假设从时刻 \tau=s 开始,我们希望求解 \tau=t 时刻的 x_t:
\begin{align}x_t-x_s&=\left[f(s)x_s+\frac{g^2(s)}{2\sigma_s}\epsilon_\theta(x_s,s)\right]\Delta t \tag{9} \\x_t-x_s&=\left[\left(\frac{\alpha_t}{\alpha_s}-1\right)\cdot x_s+\alpha_t\left(\frac{\sigma_t}{\alpha_t}-\frac{\sigma_s}{\alpha_s}\right)\cdot\epsilon_\theta(x_s,s)\right] \tag{10} \\x_t &= \frac{\alpha_t}{\alpha_s}\cdot x_s+\alpha_t\left(\frac{\sigma_t}{\alpha_t}-\frac{\sigma_s}{\alpha_s}\right)\cdot\epsilon_\theta(x_s,s) \tag{11}\end{align}
我们再看看 DDIM 的递推解噪过程,注意因为 DDIM 论文里和本文的符号系统有冲突,DDIM 中的 \alpha 其实等价于本文的 \alpha^2,所以我们统一用 a 来表示 DDIM 中的 \alpha:
\begin{align}x_t&=\sqrt{a_t}\left(\frac{x_s-\sqrt{1-a_s}\cdot\epsilon_\theta(x_s)}{\sqrt{a_s}}\right)+\sqrt{1-a_t}\cdot\epsilon_\theta(x_t) \tag{12} \\x_t&=\alpha_t\left(\frac{x_s-\sigma_s\cdot\epsilon_\theta(x_s)}{\alpha_s}\right)+\sigma_t\cdot\epsilon_\theta(x_t) \tag{13} \\x_t &= \frac{\alpha_t}{\alpha_s}\cdot x_s+\alpha_t\left(\frac{\sigma_t}{\alpha_t}-\frac{\sigma_s}{\alpha_s}\right)\cdot\epsilon_\theta(x_s,s) \tag{14} \\\end{align}
可以看到是完全一样的,DDIM 就是在用 \Delta\alpha/\Delta t 估计 {\rm d}\alpha/{\rm d}t,然后等式两边都乘上 \Delta t 就变成递推式了。所有有兴趣的同学也可以尝试一下直接计算出 {\rm d}\alpha/{\rm d}t 再显式乘上 \Delta t,EDM 就属于这种做法。
同理,要得到 DDPM,只需要在 DDIM 的基础上补一个:
c_s=\sqrt{\frac{1-a_{t}}{1-a_s}}\cdot\sqrt{1-\frac{a_s}{a_t}}=\frac{\sigma_t}{\sigma_s}\sqrt{1-\frac{\alpha^2_s}{\alpha^2_t}} \tag{15}
使用:
x_t=\alpha_t\left(\frac{x_s-\sigma_s\cdot\epsilon_\theta(x_s)}{\alpha_s}\right)+\sqrt{\sigma^2_t-c^2_s}\cdot\epsilon_\theta(x_t)+c_s\epsilon \tag{16}
就能得到 DDPM 的递推式了。这里的 \epsilon 是标准高斯噪声。
3.2 DPM-Solvers
(如果你对这些内容感到陌生且不懂,不如来康康之前的 DPM-Solver 的文章吧 80岁的公园大爷问我 DPMsolver 是什么然后我给他写了这个他说哦我看懂了 - 知乎 (zhihu.com))
DPM-Solver 注意到了 reverse ODE 的半线性,所以直接计算出了从 \tau=s 到 \tau=t 的精确解,显然,其中的线性部分是可以直接计算的,而非线性部分能写出来,但算不了:
x_t = \frac{\alpha_t}{\alpha_s}x_s-\alpha_t\int_s^t\frac{{\rm d}\lambda}{{\rm d}\tau}\frac{\sigma_\tau}{\alpha_\tau}\epsilon_\theta(x_\tau,\tau){\rm d}\tau \tag{17}
其中 \lambda=\log(\alpha/\sigma) 是 log信噪比。因此非线性部分可以用泰勒展开来估计,这就是 DPM-Solver 的核心思想。我们对非线性项进行展开,得到:
x_t = \frac{\alpha_t}{\alpha_s}x_s-\alpha_t\sum_{n=0}^{k-1}\epsilon_\theta^{(n)}(x_s,s)\int_{\lambda_s}^{\lambda_t}e^{-\lambda}\frac{(\lambda-\lambda_s)^n}{n!}{\rm d}\lambda+\mathcal O((\lambda_t-\lambda_s)^{k+1}) \tag{18}
这里虽然是一个一般形式,但是实际用的时候不会用很高阶的,因为会有不稳定的问题。我们直接来看看一阶和二阶的递推式,先是 ODE 的一阶形式,h_t=\lambda_t-\lambda_s:
ODE-DPM-Solver-1
x_t=\frac{\alpha_t}{\alpha_s}x_s-\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s) \tag{19}
把 \lambda 带进去替换掉会发现和(14)是完全一样的,所以 DPM-Solver 的一阶递推式就是 DDIM 的递推式。二阶需要计算神经网络的一阶导,这个事情比较困难,但是 DPM-Solver 证明了可以用两次的直接计算结果来估计这个一阶导的结果,这里就不展开,直接给出二阶递推式,假设 t < r < s 且记 r_1=\frac{\lambda_r-\lambda_s}{h_t}:
ODE-DPM-Solver-2M
x_t=\frac{\alpha_t}{\alpha_s}x_s-\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s)-\frac{\sigma_t(e^{h_t}-1)}{2r_1}\left(\epsilon_\theta(x_r,r)-\epsilon_\theta(x_s,s)\right) \tag{20}
那么很显然,对于 SDE,我们完全可以如法炮制地玩一遍,同样得到一阶和二阶的递推式:
SDE-DPM-Solver-1
x_t=\frac{\alpha_t}{\alpha_s}x_s-2\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s)+\sigma_t\sqrt{e^{2h_t}-1}\cdot\epsilon \tag{19}
SDE-DPM-Solver-2M
x_t=\frac{\alpha_t}{\alpha_s}x_s-2\sigma_t(e^{h_t}-1)\epsilon_\theta(x_s,s)-\frac{\sigma_t(e^{h_t}-1)}{r_1}\left(\epsilon_\theta(x_r,r)-\epsilon_\theta(x_s,s)\right)\\+\sigma_t\sqrt{e^{2h_t}-1}\cdot\epsilon \tag{19}
在 DPM-Solver++ 中作者们提出用神经网络预测干净数据 x_0 而不是单步添加的噪声,以此提高高阶 Solver 在条件生成时的稳定性。具体地说,神经网络现在变成了 x_\theta(x_t, t)\approx x_0 = \frac{x_t - \sigma_t\epsilon}{\alpha_t}。
ODE-DPM-Solver++1
x_t=\frac{\sigma_t}{\sigma_s}x_s-\alpha_t(e^{-h_t}-1)x_\theta(x_s, s) \tag{21}
ODE-DPM-Solver++2M
x_t=\frac{\sigma_t}{\sigma_s}x_s-\alpha_t(e^{-h_t}-1)x_\theta(x_s, s)-\frac{\alpha_t(e^{-h_t}-1)}{2r_1}\left(x_\theta(x_r, r)-x_\theta(x_s, s)\right) \tag{22}
SDE-DPM-Solver++1
x_t=\frac{\sigma_t}{\sigma_s}e^{-h_t}x_s+\alpha_t(1-e^{-2h_t})x_\theta(x_s, s)+\sigma_t\sqrt{1-e^{-2h_t}}\cdot\epsilon \tag{23}
SDE-DPM-Solver++2M
x_t=\frac{\sigma_t}{\sigma_s}e^{-h_t}x_s+\alpha_t(1-e^{-2h_t})x_\theta(x_s, s)-\frac{\alpha_t(1-e^{-2h_t})}{r_1}\left(x_\theta(x_r, r)-x_\theta(x_s, s)\right) \\+\sigma_t\sqrt{1-e^{-2h_t}}\cdot\epsilon \tag{24}
啊!我们又多了8种可以选择的 Solvers!
4. 现在让我们看看怎么样代码实现吧!
4.1 训练
那么根据我们之前的理论,训练就是要用一个神经网络来拟合 reverse SDE/ODE 的未知项,在本文的例子里面,就是一个 scaled score function -\sigma_t\nabla_x\log q_t(x_t)。至于用什么样的神经网络,我们并不关心,它取决于数据本身的结构特性。优化的 loss 是公式(7),伪代码如下
"""
不妨假设我们的神经网络就继承自 nn.Module 吧 ~
Inputs:
- xt: torch.Tensor, shape=(batch_size, *x_shape)
- t: torch.Tensor, shape=(batch_size, )
Outputs:
- eps: torch.Tensor, shape=(batch_size, *x_shape)
"""
nn_diffusion = MyNetwork(**kwargs)
optimizer = torch.optim.Adam(nn_diffusion.parameters(), lr=1e-4)
"""
接下来我们需要完成 noise schedule 的设计
也就是原理部分的 alpha_t, sigma_t
通常有 linear, cosine 的 schedule 方式
不论哪种,我们都假设已经设计完了,它们的形状是 (T, )
T 是 diffusion process 离散化的时间步数
是的,这里的例子是一个 Discrete 版本的模型,时间区间 [0,1] 被我们离散化成了 T 个时间步
"""
alphas = torch.tensor(...) # (T, )
sigmas = torch.tensor(...) # (T, )
# 这是一个 batch,从数据集取出来的,但这里我们用随机生成的数据代替
x0 = torch.randn(batch_size, *x_shape)
# 开始更新噜
idx = torch.randint(0, T, (batch_size, ))
alpha, sigma = alphas[idx], sigmas[idx]
eps = torch.randn_like(x0)
xt = alpha * x0 + sigma * eps # 见公式(1)
pred_eps = nn_diffusion(xt, idx) # 因为离散化了,所以idx和t是等价的
loss = ((pred_eps - eps) ** 2).mean() # 见公式(7)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Loss: {loss.item():.4f}')
OK,简单地把数据来源换成你的数据集,示例的 one-step 更新换成你的训练循环,就可以开始训练了!这里面 nn_diffusion
就是原理部分里面的 \epsilon_\theta。
4.2 采样
训练完了之后,我们就可以开始采样了!根据我们之前的理论,我们可以随意在 DDPM/DDIM/DPM-Solver 间切换,这里我们就给出一个简单的单步采样伪代码:
"""
假设我们想从时间步 s 采样到时间步 t,此刻是第 i 个 sampling step
"""
alpha_s, sigma_s = alphas[s], sigmas[s]
alpha_t, sigma_t = alphas[t], sigmas[t]
lambda_s, lambda_t = torch.log(alpha_s / sigma_s), torch.log(alpha_t / sigma_t)
h_t = lambda_t - lambda_s
c_s = sigma_t / sigma_s * torch.sqrt(1 - alpha_s ** 2 / alpha_t ** 2)
eps_theta = nn_diffusion(x_s, s) # \epsilon_\theta(x_s, s)
eps = torch.randn_like(x_s)
# 使用高阶 DPM-Solver 需要保存一些历史
buffer = []
if solver == "ddpm":
x_t = alpha_t * (x_s - sigma_s * eps_theta) / alpha_s + torch.sqrt(sigma_t ** 2 - c_s ** 2) * eps_theta
# DDPM 在最后一步不加噪声哦
if i < sampling_steps - 1:
x_t += c_s * eps
elif solver == "ddim":
x_t = alpha_t * (x_s - sigma_s * eps_theta) / alpha_s + sigma_t * eps_theta
elif solver == "ode_dpmsolver_1":
# 小技巧,用 torch.expm1 代替 torch.exp() - 1 可以缓解数值不稳定
x_t = alpha_t / alpha_s * x_s - sigma_t * (torch.expm1(h_t)) * eps_theta
elif solver == "ode_dpmsolver_2M":
buffer.append(eps_theta)
if i > 0:
r = h_s / h_t
D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
x_t = (alpha_t / alpha_s) * x_s - sigma_t * torch.expm1(h_t) * D
else:
x_t = (alpha_t / alpha_s) * x_s - sigma_t * torch.expm1(h_t) * eps_theta
elif solver == "sde_dpmsolver_1":
x_t = alpha_t / alpha_s * x_s - 2 * sigma_t * (torch.expm1(h_t)) * eps_theta + sigma_t * torch.sqrt(torch.expm1(2 * h_t)) * eps
elif solver == "sde_dpmsolver_2M":
buffer.append(eps_theta)
if i > 0:
r = h_s / h_t
D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
x_t = (alpha_t / alpha_s) * x_s - 2 * sigma_t * torch.expm1(h_t) * D + sigma_t * torch.sqrt(torch.expm1(2 * h_t)) * eps
else:
x_t = (alpha_t / alpha_s) * x_s - 2 * sigma_t * torch.expm1(h_t) * eps_theta + sigma_t * torch.sqrt(torch.expm1(2 * h_t)) * eps
elif solver == "ode_dpmsolver++_1":
x_t = sigma_t / sigma_s * x_s - alpha_t * (torch.expm1(-h_t)) * x_theta(x_s, s)
elif solver == "ode_dpmsolver++_2M":
buffer.append(x_theta(x_s, s))
if i > 0:
r = h_s / h_t
D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
x_t = (sigma_t / sigma_s) * x_s - alpha_t * torch.expm1(-h_t) * D
else:
x_t = (sigma_t / sigma_s) * x_s - alpha_t * torch.expm1(-h_t) * x_theta(x_s, s)
elif solver == "sde_dpmsolver++_1":
x_t = (sigma_t / sigma_s) * torch.exp(-h_t) * x_s + alpha_t * (1 - torch.exp(-2 * h_t)) * x_theta(x_s, s) + sigma_t * torch.sqrt(1 - torch.exp(-2 * h_t)) * eps
elif solver == "sde_dpmsolver++_2M":
buffer.append(x_theta(x_s, s))
if i > 0:
r = h_s / h_t
D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
x_t = (sigma_t / sigma_s) * torch.exp(-h_t) * x_s + alpha_t * (1 - torch.exp(-2 * h_t)) * D + sigma_t * torch.sqrt(1 - torch.exp(-2 * h_t)) * eps
else:
x_t = (sigma_t / sigma_s) * torch.exp(-h_t) * x_s + alpha_t * (1 - torch.exp(-2 * h_t)) * x_theta(x_s, s) + sigma_t * torch.sqrt(1 - torch.exp(-2 * h_t)) * eps
else:
raise NotImplementedError(f"Solver {solver} not implemented!")
# 现在我们就有了 x_t 咯 ~
哇!太爽了,10种 Solver 随便切换!只要训练一次!爽!
5. 总结
在这篇文章中,我们回顾了如何统一 DDPM/DDIM/DPM-Solver 的理论基础,以及如何在代码实现中随意切换这些 Solver。这种设计不仅让我们的代码更加灵活,也让我们的分析更加方便。希望这篇文章能帮助到你。如果有任何问题,欢迎在评论区留言,我会尽力解答。如果你觉得自己动手实现这些太麻烦了,并且你刚好是想用 Diffusion Models 做决策,那不如试试试试 CleanDiffuser 吧!这些功能都内置了,并且更灵活,更宽泛,更强大!最后,感谢你的阅读!
本文使用 Zhihu On VSCode 创作并发布
请问 公式 (3) 有证明吗?