因果扩散自动编码器:通过扩散概率模型实现反事实生成
Aneesh Komanduri ^(1,**){ }^{1, *} ,陈兆 ^(2){ }^{2} ,陈锋 ^(3){ }^{3} 和 肖恩涛 Wu^(1)\mathbf{W u}^{1}
^(1){ }^{1} 阿肯色大学 贝勒大学
德克萨斯大学达拉斯分校
摘要
扩散概率模型(DPMs)已成为高质量图像生成的最先进技术。然而,DPMs 具有任意噪声的潜在空间,没有可解释或可控的语义。尽管已经进行了大量研究来提高图像样本质量,但关于使用扩散模型进行表示控制的生成工作很少。具体来说,使用 DPMs 进行因果建模和可控反事实生成是一个尚未充分探索的领域。在这项工作中,我们提出了因果扩散自动编码器(CausalDiffAE),这是一个基于扩散的因果表示学习框架,可以根据指定的因果模型实现反事实生成。我们的关键思想是使用编码器从高维数据中提取高级语义上有意义的因果变量,并使用反向扩散来建模随机变化。我们提出了一种因果编码机制,将高维数据映射到因果相关的潜在因子,并使用神经网络参数化潜在因子之间的因果机制。为了强制因果变量的解耦,我们制定了一个变分目标,并在先验中利用辅助标签信息来正则化潜在空间。 我们提出了一种基于 DDIM 的受干预的逆事实生成过程。最后,为了解决有限的标签监督场景,我们还研究了当部分训练数据未标记时 CausalDiffAE 的应用,这也使得在推理过程中生成逆事实时能够对干预的强度进行细粒度控制。我们通过实证表明,CausalDiffAE 学习到一个解耦的潜在空间,并且能够生成高质量的逆事实图像。
1 引言
扩散概率模型(DPMs)[31, 11, 20, 32, 33] 是一类基于似然函数的生成模型,在生成高分辨率图像方面取得了显著的成功,例如 DALLE-2 [26]、Stable Diffusion [27] 和 Imagen [28] 等大规模实现。因此,对评估扩散模型的能力产生了极大的兴趣。其中两种最有前景的方法被表述为数据分布的离散时间 [11] 和连续时间 [33] 步进扰动。然后训练一个模型来估计逆向过程,该过程将噪声样本转换为来自潜在数据分布的样本。表示学习一直是 GANs [8] 和 VAEs [14] 等生成模型的一个基本组成部分,用于从复杂数据中提取鲁棒和可解释的特征 [30,2,25]。最近,研究
DPMs 能否被用来提取具有语义意义且可解码的表示,从而提高生成图像的质量和控制[21, 24]。然而,目前还没有关于在 DPMs 中建模语义潜在代码之间的因果关系,以学习因果关系表示并在推理时实现反事实生成的相关工作。在医疗保健和医学等领域,生成高质量的反事实图像至关重要[17, 29]。从领域知识中获得的因果图中生成准确的反事实数据可以显著降低数据收集的成本。此外,对训练分布中未见过的假设情景进行推理,对于评估复杂系统中因果变量之间的相互作用具有很大的洞察力。给定一个系统的因果图,我们研究了 DPMs 作为因果关系学习者的能力,并评估了它们在干预因果变量时生成反事实的能力。
直观上,我们可以将 DPM 视为一个编码器-解码器框架。编码器通过一系列高斯噪声扰动将输入图像 x_(0)\mathbf{x}_{0} 映射到一个空间潜在变量 x_(T)\mathbf{x}_{T} 。然而, x_(T)\mathbf{x}_{T} 可以解释为一个缺乏高级语义的噪声表示[24]。最近,Preechakul 等人[24]提出了一种基于扩散的自编码器(DiffAE),用于提取与可解码表示学习相关的随机低级表示 x_(T)\mathbf{x}_{T} 以及高级语义表示。学习这样的语义表示也使得在潜在空间中进行插值成为可能,以实现可控生成,并且已被证明可以提高图像生成质量。Mittal 等人[19]在此基础上构建了一个框架,并引入了一个基于扩散的表示学习(DRL)目标,该目标在整个扩散过程中学习时间条件下的表示。然而,这两种方法都学习任意表示,并且不关注解耦,这是可解释表示的关键属性。解耦表示能够实现对生成因素的精确控制。 当考虑因果系统时,解耦对于执行隔离干预至关重要。
在这篇论文中,我们专注于学习可分离的因果表示,其中高级语义因素是因果相关的。据我们所知,我们是第一个探索使用扩散概率模型进行基于表示的反事实图像生成的。我们提出了因果扩散自动编码器(CausalDiffAE),这是一种用于在 DPM 中进行因果表示学习和可控反事实生成的学习框架。我们的关键思想是通过可学习的随机编码器学习因果表示,并通过由神经网络参数化的因果机制来建模潜在变量之间的关系。我们通过一个标签对齐先验来制定一个变分目标,以强制实现解耦。
学习因果因素的纠缠。然后,我们利用条件去噪扩散隐式模型(DDIM)[32]进行解码和建模随机变化。直观上,因果表示编码了与反向扩散中图像解码相关的紧凑信息。此外,在潜在空间中对因果关系的建模使得在干预学习到的因果变量时能够生成反事实。我们提出了一种 DDIM 变体,用于在干预 (*)(\cdot) 的情况下进行反事实生成。为了提高模型的实用性和可解释性,我们提出了 CausalDiffAE 的扩展,该扩展利用了较弱的监督。在标记数据有限的情况下,我们分别对未标记和标记分区联合训练一个无条件和表示条件的扩散模型。这种方法显著减少了训练所需的标记样本数量,并能够对干预的强度和生成的反事实的质量进行细粒度控制。
近期因果生成建模的研究主要集中在学习因果表示或可控反事实生成[16]。Yang 等人提出了 CausalVAE[35],这是一个通过线性结构因果模型(SCM)来建模潜在因果变量的因果表示学习框架。Kocaoglu 等人[15]提出了 CausalGAN,这是 GAN 的一个扩展,用于建模因果变量以从干预分布中进行采样。扩散和基于分数的生成模型[11, 33]在类条件生成方面已经显示出令人印象深刻的结果,无论是基于分类器[5]还是无分类器[10]的方法。最近,人们开始对探索扩散模型作为表示学习者的能力产生了兴趣。例如,Mittal 等人[19]和 Preechakul 等人[24]考虑了基于扩散的表示学习目标。Mamaghan 等人[12]在能够访问反事实对形式的数据的情况下,从基于分数的角度探讨了表示学习。然而,这项工作并不专注于反事实生成。 另一个相关的研究领域是反事实解释[1],它侧重于生成现实反事实的后处理方法,但不是严格意义上的因果关系。我们的工作侧重于基于扩散的表示学习,与 DiffAE[24]和 DRL[19]最为接近,它们旨在学习语义上有意义的表示。然而,关键的区别在于我们学习因果关系表示以实现反事实生成。我们提出的框架将 CausalVAE 扩展到基于扩散的模型,并在较弱的监督范式下进行。
3 背景
3.1 结构因果模型
一个结构因果模型(SCM)由一个元组形式正式定义,其中 Z\mathcal{Z} 是 nn 内生因果变量集合的域, z={z_(1),dots,z_(n)},U\mathbf{z}=\left\{z_{1}, \ldots, z_{n}\right\}, \mathcal{U} 是 nn 外生噪声变量集合的域, u={u_(1),dots,u_(n)}\mathbf{u}=\left\{u_{1}, \ldots, u_{n}\right\} 作为中间潜在变量学习,而 F={f_(1),dots,f_(n)}F=\left\{f_{1}, \ldots, f_{n}\right\} 是一组 nn 独立因果机制
z_(i)=f_(i)(u_(i),z_(pa_(i)))z_{i}=f_{i}\left(u_{i}, z_{\mathbf{p a}_{i}}\right)
其中 AA i,f_(i):U_(i)xxprod_(j inpa_(i))Z_(j)rarrZ_(i)\forall i, f_{i}: \mathcal{U}_{i} \times \prod_{j \in \mathbf{p a}_{i}} \mathcal{Z}_{j} \rightarrow \mathcal{Z}_{i} 是决定每个因果变量作为父母和噪声函数的因果机制, z_(pa_(i))z_{\mathbf{p a}_{i}} 是因果变量 z_(i)z_{i} 的父母;以及一个概率测度 p_(U)(u)=p_(U_(1))(u_(1))p_(U_(2))(u_(2))dotsp_(U_(n))(u_(n))p_{\mathcal{U}}(\mathbf{u})=p_{\mathcal{U}_{1}}\left(u_{1}\right) p_{\mathcal{U}_{2}}\left(u_{2}\right) \ldots p_{\mathcal{U}_{n}}\left(u_{n}\right) ,它允许乘积分布。为了本工作的目的,我们假设因果机制
充分设置(无隐藏混杂因素),无 SCM 误设,并且忠实性得到满足。
3.2 扩散概率模型
扩散概率模型(DPMs)[11, 20]在图像生成任务中表现出令人印象深刻的结果,甚至在许多情况下超过了 GANs[5]。去噪扩散概率模型(DDPM)[11]的思路是定义一个马尔可夫链的扩散步骤,通过添加噪声[11]通过正向扩散过程缓慢破坏数据分布中的结构,并学习一个反向扩散过程来恢复数据的结构。一些提出的方法,如去噪扩散隐式模型(DDIM)[32],通过执行噪声的确定性编码来打破马尔可夫假设,从而加快扩散过程中的采样速度。
正向扩散。给定从分布 x_(0)∼q(x)\mathbf{x}_{0} \sim q(\mathbf{x}) 中采样的某些输入数据,正向扩散过程通过在 TT 步中逐步添加少量高斯噪声到样本,从而产生噪声样本 x_(1),dots,x_(T)\mathbf{x}_{1}, \ldots, \mathbf{x}_{T} 。在时间步 tt 的噪声样本分布被定义为如下条件分布:
q(x_(t)∣x_(t-1))=N(x_(t);sqrt(1-beta_(t))x_(t-1),beta_(t)I)q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathbf{I}\right)
beta_(t)in(0,1)\beta_{t} \in(0,1) 是一个方差参数,它控制噪声的步长。作为 t rarr oot \rightarrow \infty ,输入样本 x_(0)\mathbf{x}_{0} 失去了其可区分的特征。最后,当 t=T,x_(T)t=T, \mathbf{x}_{T} 符合各向同性高斯分布。从公式 (2) 中,我们可以定义一个封闭形式的可追踪后验,它在所有时间步上分解如下:
q(x_(1:T)∣x_(0))=prod_(t=1)^(T)q(x_(t)∣x_(t-1))q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)=\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)
现在, x_(t)\mathbf{x}_{t} 可以在任何任意的时间步 tt 使用重参数化技巧进行采样。令 alpha_(t)=prod_(i=1)^(t)1-beta_(i)\alpha_{t}=\prod_{i=1}^{t} 1-\beta_{i} :
q(x_(t)∣x_(0))=N(x_(t);sqrt(alpha_(t))x_(0),(1-alpha_(t))I)q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{\alpha_{t}} \mathbf{x}_{0},\left(1-\alpha_{t}\right) \mathbf{I}\right)
反向扩散。在反向过程中,为了从 q(x_(t-1)∣x_(t))q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) 采样,目标是使用高斯噪声输入 x_(T)∼N(0,I)\mathbf{x}_{T} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) 重新创建真实样本 x_(0)\mathbf{x}_{0} 。与正向扩散不同, q(x_(t-1)∣x_(t))q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) 不是解析可处理的,因此需要学习一个模型 p_(theta)p_{\theta} 来近似条件分布,如下所示:
{:[p_(theta)(x_(0:T))=p(x_(T))prod_(t=1)^(T)p_(theta)(x_(t-1)∣x_(t))],[p_(theta)(x_(t-1)∣x_(t))=N(x_(t-1);mu_(theta)(x_(t),t),Sigma_(theta)(x_(t),t))]:}\begin{aligned}
p_{\theta}\left(\mathbf{x}_{0: T}\right) & =p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) \\
p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) & =\mathcal{N}\left(\mathbf{x}_{t-1} ; \mu_{\theta}\left(\mathbf{x}_{t}, t\right), \Sigma_{\theta}\left(\mathbf{x}_{t}, t\right)\right)
\end{aligned}
mu_(theta)\mu_{\theta} 和 Sigma_(theta)\Sigma_{\theta} 通过神经网络学习得到。结果证明,对输入 x_(0)\mathbf{x}_{0} 进行条件化可以得到一个可处理的逆条件概率
q(x_(t-1)∣x_(t),x_(0))=N(x_(t-1);( tilde(mu))(x_(t),x_(0)), tilde(beta)_(t)I)q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\mu}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right), \tilde{\beta}_{t} \mathbf{I}\right)
tilde(mu)\tilde{\mu} 和 tilde(beta)_(t)\tilde{\beta}_{t} 是真实均值和方差。学习目标随后通过重新参数化以最小化以下均方误差损失,被表述为一个简化的 ELBO 目标。
L_("simple ")=sum_(t=1)^(T)E_(x_(0),epsilon_(t))[||epsilon_(t)-epsilon_(theta)(x_(t),t)||_(2)^(2)]\mathcal{L}_{\text {simple }}=\sum_{t=1}^{T} \mathbb{E}_{\mathbf{x}_{0}, \epsilon_{t}}\left[\left\|\epsilon_{t}-\epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|_{2}^{2}\right]
其中 epsilon_(t)∼N(0,I)\epsilon_{t} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) 是通过从 x_(0)\mathbf{x}_{0} 的重新参数化得到的一种分析形式的噪声,如[11]所示。

图 1:因果扩散自动编码器框架。左侧详细说明了因果扩散自动编码器(CausalDiffAE)的训练过程,通过编码到因果表示 z_("causal ")\mathbf{z}_{\text {causal }} ,并使用条件 DDIM 解码器,基于 z_("causal ")\mathbf{z}_{\text {causal }} 和 x_(T)\mathbf{x}_{T} 进行图像重建。右侧展示了使用训练好的因果扩散自动编码器模型进行基于 DDIM 的反事实生成过程。
DPMs 通过正向过程产生潜在变量 x_(1:T)\mathbf{x}_{1: T} 。然而,这些变量是随机的[24]。Song 等人提出了一种名为去噪扩散隐式模型(DDIM)的 DPM,它实现了一个确定性过程如下:
x_(t-1)=sqrt(alpha_(t-1))((x_(t)-sqrt(1-alpha_(t))epsilon_(theta)(x_(t),t))/(sqrt(alpha_(t))))+sqrt(1-alpha_(t-1))epsilon_(theta)(x_(t),t)\mathbf{x}_{t-1}=\sqrt{\alpha_{t-1}}\left(\frac{\mathbf{x}_{t}-\sqrt{1-\alpha_{t}} \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)}{\sqrt{\alpha_{t}}}\right)+\sqrt{1-\alpha_{t-1}} \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)
具有以下确定性解码过程
q(x_(t-1)∣x_(t),x_(0))=N(sqrt(alpha_(t-1))x_(0)+sqrt(1-alpha_(t-1))(x_(t)-sqrt(alpha_(t))x_(0))/(sqrt(1-alpha_(t))),0)q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)=\mathcal{N}\left(\sqrt{\alpha_{t-1}} \mathbf{x}_{0}+\sqrt{1-\alpha_{t-1}} \frac{\mathbf{x}_{t}-\sqrt{\alpha_{t}} \mathbf{x}_{0}}{\sqrt{1-\alpha_{t}}}, \mathbf{0}\right)
该公式保持了 DDPM 的边缘分布 q(x_(t)∣x_(0))=q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)= N(sqrt(alpha_(t-1))x_(0),(1-alpha_(t))I)\mathcal{N}\left(\sqrt{\alpha_{t-1}} \mathbf{x}_{0},\left(1-\alpha_{t}\right) \mathbf{I}\right) 。结果证明,这种公式与 DDPM 具有相同的目标和解决方案,只是在采样程序上有所不同。因此,我们可以确定地获得与给定图像 x_(0)\mathbf{x}_{0} 相对应的噪声图 x_(T)\mathbf{x}_{T} 。
4 因果扩散自动编码器
现有基于扩散的可控生成方法忽略了生成因素之间存在因果关系的情况,并且不支持反事实生成。为了解决这一问题,我们提出了 CausalDiffAE,一个基于扩散的因果表示学习框架,以实现反事实生成。首先,我们定义一个潜在结构因果模型(SCM)来描述语义因果表示作为学习到的噪声编码的函数。在扩散自动编码器[24]的情况下,语义潜在表示 z_("sem ")\mathbf{z}_{\text {sem }} 捕获高级语义信息,而 x_(T)\mathbf{x}_{T} 捕获低级随机信息。在我们的公式中,我们学习一个因果表示 z_("causal ")\mathbf{z}_{\text {causal }} ,它捕获因果相关的信息。这两个潜在变量( z_("causal "),x_(T)\mathbf{z}_{\text {causal }}, \mathbf{x}_{T} )共同捕获图像中所有详细的因果语义和随机性。其次,给定一个训练好的 CausalDiffAE 模型,我们提出了一种反事实生成算法,该算法利用 do (*)(\cdot) 干预和 DDIM 采样算法。CausalDiffAE 的整体框架如图 1 所示。
4.1 因果编码
令 x_(0)inR^(d)\mathbf{x}_{0} \in \mathbb{R}^{d} 为观察到的输入图像。我们进行正向扩散过程,直到得到一组 TT 混扰样本 {x_(1),x_(2),dots,x_(T)}\left\{\mathbf{x}_{1}, \mathbf{x}_{2}, \ldots, \mathbf{x}_{T}\right\} ,每个样本具有不同的噪声尺度。假设存在 nn 个抽象因果变量,描述观察图像的高级语义。为了学习有意义的表示,我们提出将输入图像 x_(0)\mathbf{x}_{0} 编码到低维噪声编码 uinR^(n)\mathbf{u} \in \mathbb{R}^{n} 。然后,我们将噪声编码映射到与抽象因果变量对应的潜在因果因子 z_("causal ")inR^(n)\mathbf{z}_{\text {causal }} \in \mathbb{R}^{n} 。在这个公式中,每个噪声项 u_(i)u_{i} 是 SCM 中因果变量 z_(i)z_{i} 的外生噪声项。令 A\mathbf{A} 为编码底层因素之间因果图的邻接矩阵,其中 A_(ji)A_{j i} 表示 z_(j)z_{j} 是 z_(i)z_{i} 的原因。然后,我们如下参数化因果变量之间的机制:
z_(i)=f_(i)(z_(pa_(i)),u_(i))z_{i}=f_{i}\left(z_{\mathbf{p a}_{i}}, u_{i}\right)
f_(i)f_{i} 是生成因果变量 z_(i)z_{i} 作为其父节点和外部噪声项的函数的因果机制, z_(pa_(i))z_{\mathbf{p a}_{i}} 表示因素 z_(i)z_{i} 的因果父节点。在实践中,我们可以将 f_(i)f_{i} 实现为一个后非线性加性噪声模型,以便
{:[z=(I-A^(T))^(-1)u],[z_(i)=f_(i)(A_(i)o.z;nu_(i))+u_(i)]:}\begin{aligned}
\mathbf{z} & =\left(I-\mathbf{A}^{T}\right)^{-1} \mathbf{u} \\
z_{i} & =f_{i}\left(\mathbf{A}_{i} \odot \mathbf{z} ; \nu_{i}\right)+u_{i}
\end{aligned}
其中 nu_(i)\nu_{i} 是参数化每个机制的神经网络的参数, o.\odot 是逐元素乘积, z_("causal ")=\mathbf{z}_{\text {causal }}= {z_(1),dots,z_(n)}\left\{z_{1}, \ldots, z_{n}\right\} 。本模块使用神经结构因果模型捕捉潜在变量之间的因果关系。在本工作中,我们假设因果图是已知的,因为我们专注于反事实生成。然而,一个更端到端的框架可能包括一个因果发现组件。有关更详细的讨论,请参阅附录 C。
4.2 生成模型
令 x_(0)\mathbf{x}_{0} 表示高维输入图像, yinR^(n)\mathbf{y} \in \mathbb{R}^{n} 表示辅助弱监督信号。然后,因果扩散自动编码器的生成过程可以分解如下: