

Energy-based Diffusion Generator for Efficient Sampling of Boltzmann Distributions

Yan Wang^a, Ling Guo^b, Hao Wu^{*c}, Tao Zhou^d
Abstract


Sampling from Boltzmann distributions, especially those associated with high-dimensional and complex energy functions, poses a significant challenge in many fields. In this work, we propose the energy-based diffusion generator (EDG), a novel approach that integrates ideas from variational autoencoders and diffusion models. EDG uses a decoder to transform latent variables from a simple distribution into samples approximating the target Boltzmann distribution, while a diffusion-based encoder provides an accurate estimate of the Kullback-Leibler divergence during training. Notably, EDG is simulation-free: it requires no solution of ordinary or stochastic differential equations during training. Moreover, by removing constraints such as bijectivity from the decoder, EDG allows for flexible network design. Through empirical evaluation, we demonstrate the superior performance of EDG across a variety of complex distribution tasks, outperforming existing methods.


Keywords:
Boltzmann distribution, energy-based models, generative models, diffusion models, variational autoencoders
Affiliations:
a. School of Mathematical Sciences, Tongji University, Shanghai, China
b. Department of Mathematics, Shanghai Normal University, Shanghai, China
c. School of Mathematical Sciences, Institute of Natural Sciences and MOE-LSC, Shanghai Jiao Tong University, Shanghai, China (*Corresponding author: hwu81@sjtu.edu.cn)
d. LSEC, Institute of Computational Mathematics and Scientific/Engineering Computing, AMSS, Chinese Academy of Sciences, Beijing, China


1 Introduction


The challenge of sampling from Boltzmann distributions corresponding to high-dimensional and complex energy functions is ubiquitous across fields such as computational chemistry, statistical physics, and machine learning [1, 2]. Unlike the training of data-driven generative models, which can leverage pre-sampled data to learn complex distributions, the sampling of Boltzmann distributions presents a unique and formidable challenge due to the lack of readily available data [3, 4]. For example, simulating phase transitions of the Ising model can be viewed as a sampling problem for a given energy function, a complex and difficult problem that remains without an effective solution [5, 6].


Markov chain Monte Carlo (MCMC) methods [7], as well as Brownian and Hamiltonian dynamics [8, 9, 10, 11], provide key solutions to the difficult problem of sampling from high-dimensional distributions. These methods iteratively generate candidate samples and update them, achieving asymptotic unbiasedness in the limit of infinitely many sampling steps. In recent years, researchers have proposed adaptive MCMC strategies for generating candidate samples, yielding notable progress in the efficiency and effectiveness of the sampling process [12, 13, 14]. However, the long mixing times of MCMC still constrain its performance. Several studies have shown that constructing and optimizing proposal distributions with neural networks can significantly improve the efficiency of MCMC [13, 15, 16]. Nevertheless, effective and broadly applicable loss functions to drive such optimization are still lacking.


Variational inference (VI) is another important approach to intractable distributions. VI approximates the target Boltzmann distribution with a generator capable of producing samples rapidly, and optimizes the generator parameters to minimize a statistical distance, such as the Kullback-Leibler (KL) divergence, between the distribution of generated samples and the target distribution. Owing to their ability to model complex distributions while providing explicit probability density functions, normalizing flows (NFs) have been widely used to construct generators in VI methods [17, 18, 19, 20, 21, 22, 23, 24, 25, 26]. However, the bijectivity of NFs limits their effective capacity, often rendering them insufficient for certain sampling tasks. Given the target density function and generated samples, the Stein discrepancy [27, 28] provides an alternative measure of goodness of fit, but the computation of kernel functions and their gradients limits its performance in high-dimensional tasks. Moreover, the combination of MCMC and VI methods is a current research focus [29, 30, 31, 32, 33, 34]. Such combinations attempt to exploit the strengths of both approaches, offering a promising avenue for addressing the challenges associated with sampling from high-dimensional distributions and improving the efficiency of probabilistic modeling.


With the rapid development of diffusion-based generative models [35, 36, 37, 38], they have been applied to tackle difficulties in sampling problems. By training time-dependent score-matching neural networks, the methods proposed in [39, 40, 41] shape a Gaussian distribution into complex target densities, employing the KL divergence as the loss function. To alleviate the mode-seeking problem, [42] introduced a log-variance loss that exhibits favorable properties. In addition, [43] outlines an alternative training objective relying on a flexible interpolation of the energy function, yielding substantial improvements for multimodal targets. A common drawback of these methods, however, is their reliance on numerical differential equation solvers to compute time integrals, which can incur significant computational cost.


In this work, drawing inspiration from variational autoencoder (VAE) techniques [44] and diffusion models, we propose a novel approach termed the energy-based diffusion generator (EDG). The architecture of EDG closely resembles that of a VAE, comprising a decoder and an encoder. The decoder flexibly maps latent variables drawn from a tractable distribution to samples, without imposing constraints such as bijectivity, and in this work we design a decoder based on generalized Hamiltonian dynamics to improve sampling efficiency. The encoder employs a diffusion process, allowing score-matching techniques to be applied for accurate and efficient modeling of the conditional distribution of latent variables given samples. Unlike existing diffusion-based methods, the loss function of EDG admits unbiased estimates in a stochastic mini-batch manner, without numerically solving ordinary or stochastic differential equations during training. Numerical experiments demonstrate the effectiveness of EDG.


2 Preliminaries and setup


In this work, we study the task of constructing a generative model to sample from the Boltzmann distribution driven by a given energy $U:\mathbb{R}^{d}\to\mathbb{R}$:

$$\pi(x)=\frac{1}{Z}\exp(-U(x)),$$


where the normalization constant $Z=\int\exp(-U(x))\,\mathrm{d}x$ is typically intractable. To address this challenge, the Boltzmann generator [18] and its various extensions [24, 25, 26] have emerged as prominent techniques in recent years. These methods use NFs to parameterize a trainable analytic density function and optimize the parameters by minimizing the KL divergence between the surrogate density and $\pi$. However, unlike typical generative models, the pursuit of exact probability density computation imposes substantial restrictions on NFs: every transformation layer must be bijective, with a Jacobian determinant that can be computed efficiently. These requirements inherently limit the capacity of NFs to effectively model complex distributions.
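To make the setting concrete, here is a minimal sketch (our illustration, not code from the paper) of a two-mode target energy in the style of the MoG2 benchmark from Section 4; the means `mu1`, `mu2` and variance `sigma2` are assumptions chosen for illustration.

```python
import torch

# Assumed MoG2-style benchmark energy: two isotropic Gaussians with
# centers 10 apart and common variance 0.5, so pi(x) ∝ exp(-U(x)).
mu1 = torch.tensor([-5.0, 0.0])
mu2 = torch.tensor([5.0, 0.0])
sigma2 = 0.5

def U(x: torch.Tensor) -> torch.Tensor:
    """Energy of an equal-weight two-component Gaussian mixture.
    x has shape (batch, 2); additive constants are dropped."""
    log_n1 = -((x - mu1) ** 2).sum(-1) / (2 * sigma2)
    log_n2 = -((x - mu2) ** 2).sum(-1) / (2 * sigma2)
    return -torch.logsumexp(torch.stack([log_n1, log_n2], dim=-1), dim=-1)
```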


We now shift our focus to a generator similar to that of a VAE, which produces samples through a decoder:

$$x|z_{0}\sim p_{D}(x|z_{0},\phi),$$

where $z_{0}\sim p_{D}(z_{0})$ is a latent variable drawn from a known prior distribution, typically the standard multivariate normal distribution. The parameters $\phi$ characterize the decoder, and we define $p_{D}(x|z_{0};\phi)$ as a Gaussian distribution $\mathcal{N}(x|\mu(z_{0};\phi),\Sigma(z_{0};\phi))$, where both $\mu$ and $\Sigma$ are parameterized by neural networks (NNs). As in a VAE, our goal is to train the networks $\mu$ and $\Sigma$ so that the marginal distribution $p_{D}(x)$ of generated samples matches the target distribution.
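As an illustration of this parameterization, the following sketch implements a Gaussian decoder with NN-parameterized mean and diagonal covariance; the diagonal covariance and the two-layer network are simplifying assumptions of ours, not a specification from the paper. The reparameterized sampling keeps gradients with respect to $\phi$, which the loss in Section 3.2 requires.

```python
import math
import torch
import torch.nn as nn

class GaussianDecoder(nn.Module):
    """p_D(x|z0; phi) = N(x | mu(z0), diag(sigma(z0)^2)); the diagonal
    covariance and the two-layer net are simplifying assumptions."""

    def __init__(self, latent_dim: int, data_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, 2 * data_dim),
        )

    def sample_and_log_prob(self, z0: torch.Tensor):
        mu, log_sigma = self.net(z0).chunk(2, dim=-1)
        x = mu + log_sigma.exp() * torch.randn_like(mu)  # reparameterized
        log_p = (-0.5 * ((x - mu) / log_sigma.exp()) ** 2
                 - log_sigma - 0.5 * math.log(2 * math.pi)).sum(-1)
        return x, log_p
```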


Note that, unlike in conventional data-driven VAEs, samples from the target distribution $\pi(x)$ are unavailable; indeed, obtaining such samples is precisely the goal of the generator. Consequently, the variational approximation of the KL divergence $D_{KL}\left(\pi(x)\,||\,p_{D}(x)\right)$ cannot be used to train the model. Instead, in this work we consider the following divergence and its upper bound:

$$
D_{KL}\left(p_{D}(x)\,||\,\pi(x)\right) \;\leq\; D_{KL}\left(p_{D}(z_{0})\,p_{D}(x|z_{0},\phi)\,||\,\pi(x)\,p_{E}(z_{0}|x,\theta)\right)
= \mathbb{E}_{p_{D}(z_{0})\,p_{D}(x|z_{0},\phi)}\left[\log\frac{p_{D}(z_{0})\,p_{D}(x|z_{0},\phi)}{p_{E}(z_{0}|x,\theta)}+U(x)\right]+\log Z. \tag{1}
$$


Here, the parametric distribution $p_{E}(z_{0}|x,\theta)$ defines an encoder mapping from $x$ to the latent variable $z_{0}$, and equality is achieved if $p_{E}(z_{0}|x,\theta)$ matches the conditional distribution of $z_{0}$ given $x$ induced by the decoder.


It may appear that we have merely added complexity to the problem, since we still need to approximate a conditional distribution. However, in the following sections we demonstrate how diffusion models [35, 37] can be leveraged to construct the encoder effectively and to optimize all parameters without numerically solving ordinary or stochastic differential equations.


3 Energy-based diffusion generator


Diffusion models [37, 38] have recently emerged as a powerful approach to estimating data distributions. Their core idea is to construct a diffusion process that progressively transforms data into simple white noise, and then to learn the reverse process that recovers the data distribution from noise. In this work, we apply the principles of diffusion models by introducing a diffusion process in the latent space, which allows us to effectively overcome the challenges posed by the variational framework for the sampling problem defined in Eq. (1). We refer to the resulting model as the energy-based diffusion generator (EDG).


3.1 Model architecture


In the EDG framework, we initiate a diffusion process from the latent variable $z_{0}$ and combine it with the decoder to form what we call the decoding process:

$$
z_{0}\in\mathbb{R}^{D}\sim p_{D}(z_{0})\triangleq\mathcal{N}(0,I),\quad x|z_{0}\sim p_{D}(x|z_{0};\phi),
$$
$$
\mathrm{d}z_{t}=f(z_{t},t)\,\mathrm{d}t+g(t)\,\mathrm{d}W_{t},\quad t\in[0,T], \tag{2}
$$

where $W_{t}$ is a standard Wiener process, $f(\cdot,t):\mathbb{R}^{D}\to\mathbb{R}^{D}$ is the drift coefficient, and $g(\cdot):\mathbb{R}\to\mathbb{R}$ is the diffusion coefficient. To simplify notation, we denote by $p_{D}$ the probability distribution defined by the decoding process. As in the typical SDEs used in diffusion models, two key conditions hold: (a) the transition density $p_{D}(z_{t}|z_{0})$ can be computed analytically without numerically solving the Fokker-Planck equation, and (b) $z_{T}$ with $p_{D}(z_{T})\approx p_{D}(z_{T}|z_{0})$ is approximately uninformative.
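For instance, both conditions are satisfied by the variance-preserving SDE commonly used in [37], with $f(z,t)=-\frac{1}{2}\beta(t)z$ and $g(t)=\sqrt{\beta(t)}$; this specific choice and the linear schedule below are assumptions for illustration, not the paper's prescribed SDE. The sketch samples $z_t$ given $z_0$ directly from the analytic transition kernel.

```python
import torch

# Assumed linear schedule beta(s) = beta_min + (beta_max - beta_min) * s / T,
# whose time integral B(t) is available in closed form.
beta_min, beta_max, T = 0.1, 20.0, 1.0

def B(t: torch.Tensor) -> torch.Tensor:
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2 / T

def sample_zt_given_z0(z0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Draw z_t ~ p_D(z_t|z_0) = N(z_0 exp(-B(t)/2), (1 - exp(-B(t))) I),
    i.e. condition (a): no Fokker-Planck solve is needed."""
    mean = z0 * torch.exp(-0.5 * B(t))
    std = torch.sqrt(1.0 - torch.exp(-B(t)))
    return mean + std * torch.randn_like(z0)
```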


If we consider only the statistical properties of the latent diffusion process $z_{[\cdot]}=\{z_{t}\}_{t\in[0,T]}$, it is uninformative, describing merely a transition from one simple noise to another. However, when we consider the conditional distribution of $z_{t}$ given a sample $x$, the process $z_{[\cdot]}$ characterizes how the complex conditional distribution $p_{D}(z_{0}|x)\propto p_{D}(z_{0})\,p_{D}(x|z_{0})$ is progressively transformed into the tractable distribution $p_{D}(z_{T}|x)=p_{D}(z_{T})$, where the independence between $z_{T}$ and $x$ follows from the independence between $z_{0}$ and $z_{T}$ (see A). This implies that, starting from $z_{T}\sim p_{D}(z_{T})$, we can obtain samples of $p_{D}(z_{0}|x)$ by simulating the following reverse-time diffusion equation [45]:

$$
\mathrm{d}z_{\tilde{t}}=-\left(f(z_{\tilde{t}},\tilde{t})-g(\tilde{t})^{2}\,\nabla_{z_{\tilde{t}}}\log p_{D}(z_{\tilde{t}}|x)\right)\mathrm{d}\tilde{t}+g(\tilde{t})\,\mathrm{d}W_{\tilde{t}}, \tag{3}
$$


where $\tilde{t}=T-t$ denotes reverse time. As in conventional diffusion models, the practical implementation of this simulation is challenging due to the intractability of the score function $\nabla_{z_{\tilde{t}}}\log p_{D}(z_{\tilde{t}}|x)$, so we likewise approximate the score function with a neural network, denoted $s(z_{\tilde{t}},x,\tilde{t};\theta)$. This approximation leads to what we call the encoding process, obtained by combining the parametric reverse-time diffusion process with the target distribution of $x$:

$$
x\sim\pi(x),\quad z_{T}\sim p_{E}(z_{T})\triangleq p_{D}(z_{T}),
$$
$$
\mathrm{d}z_{\tilde{t}}=-\left(f(z_{\tilde{t}},\tilde{t})-g(\tilde{t})^{2}\,s(z_{\tilde{t}},x,\tilde{t};\theta)\right)\mathrm{d}\tilde{t}+g(\tilde{t})\,\mathrm{d}W_{\tilde{t}},\quad\tilde{t}=T-t. \tag{4}
$$


To simplify notation, throughout this paper we denote by $p_{E}$ the distribution defined by the encoding process.


1 直观描述了解码和编码过程。需要强调的是,潜在扩散模型主要用于解决数据驱动场景中的生成建模问题,最近受到了广泛关注[46, 47, 48] 。他们的主要想法是使用预先训练好的编码器和解码器来获得一个既能有效表示数据又能促进高效采样的潜在空间,并通过扩散模型来学习潜在变量的分布。我们的 EDG 模型利用类似的理念来解决基于能量的采样问题。EDG 与之前的潜变量扩散模型在结构和算法上的主要区别如下:首先,在 EDG 中,扩散模型本身就是编码器,无需单独的编码器;其次,通过使用统一的损失函数,解码器与扩散模型共同训练(见第 3.2 节)。


Figure 1: Probabilistic graphical models of the decoding and encoding processes, where the gray parts denote the trainable components of the model.


Below, we describe the construction details of the EDG modules used in our experiments. In practice, more effective neural networks can be designed as needed.


3.1.1 Boundary-condition-guided score function model


Considering that the true score function satisfies the following boundary conditions at $t=0,T$:

$$
\nabla_{z_{0}}\log p_{D}(z_{0}|x)
= \nabla_{z_{0}}\left[\log p_{D}(z_{0}|x)+\log p_{D}(x)\right]
= \nabla_{z_{0}}\log p_{D}(x,z_{0})
= \nabla_{z_{0}}\left[\log p_{D}(x|z_{0})+\log p_{D}(z_{0})\right]
$$

and

$$
\nabla_{z_{T}}\log p_{D}(z_{T}|x)=\nabla_{z_{T}}\log p_{D}(z_{T}),
$$


we propose to express $s(z,x,t;\theta)$ as

$$
s(z,x,t;\theta) = \left(1-\frac{t}{T}\right)\nabla_{z_{0}}\left[\log p_{D}(x|z_{0}=z)+\log p_{D}(z_{0}=z)\right]
+\frac{t}{T}\,\nabla_{z_{T}}\log p_{D}(z_{T}=z)+\frac{t}{T}\left(1-\frac{t}{T}\right)s^{\prime}(z,x,t;\theta),
$$

where $s^{\prime}(z,x,t;\theta)$ is the neural network to be trained. This formulation ensures that the error of $s$ vanishes at both $t=0$ and $t=T$.
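A minimal sketch of this interpolation is given below; `score_net` plays the role of $s^{\prime}(z,x,t;\theta)$, and the two gradient callables are hypothetical names assumed to return the exact boundary scores computed from the decoder and from the prior at time $T$.

```python
def guided_score(z, x, t, T, score_net, grad_log_joint, grad_log_prior_T):
    """Interpolated score s(z, x, t; theta) with exact values at t=0 and t=T.
    grad_log_joint(z, x) returns grad_z [log p_D(x|z0=z) + log p_D(z0=z)],
    grad_log_prior_T(z) returns grad_z log p_D(z_T=z), and score_net is the
    trainable correction s'(z, x, t; theta)."""
    w = t / T
    return ((1 - w) * grad_log_joint(z, x)
            + w * grad_log_prior_T(z)
            + w * (1 - w) * score_net(z, x, t))
```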


3.1.2 Decoder based on generalized Hamiltonian dynamics


Inspired by generalized Hamiltonian dynamics (GHD) [12, 13], the decoder generates the output $x$ through the following procedure. First, an initial sample and velocity $(y,v)$ are generated from the latent variable $z_{0}$. Then, $(y,v)$ is updated iteratively as follows:

$$
\begin{aligned}
v &:= v-\frac{\epsilon(l;\phi)}{2}\left(\nabla U(y)\odot e^{\frac{\epsilon_{0}}{2}Q_{v}(y,\nabla U(y),l;\phi)}+T_{v}(y,\nabla U(y),l;\phi)\right),\\
y &:= y+\epsilon(l;\phi)\left(v\odot e^{\epsilon_{0}Q_{y}(v,l;\phi)}+T_{y}(v,l;\phi)\right),\\
v &:= v-\frac{\epsilon(l;\phi)}{2}\left(\nabla U(y)\odot e^{\frac{\epsilon_{0}}{2}Q_{v}(y,\nabla U(y),l;\phi)}+T_{v}(y,\nabla U(y),l;\phi)\right).
\end{aligned}
$$


Finally, the decoder output $x$ is given by

$$
x=y-\exp(\epsilon_{0}\eta(y;\phi))\,\nabla U(y)+\sqrt{2\exp(\epsilon_{0}\eta(y;\phi))}\,\xi,
$$

where $\xi$ follows the standard Gaussian distribution. This equation can be interpreted as a finite-step approximation of the Brownian dynamics $\mathrm{d}y=-\nabla U(y)\,\mathrm{d}t+\sqrt{2}\,\mathrm{d}W_{t}$. In the above procedure, $Q_{v}$, $T_{v}$, $Q_{y}$, $T_{y}$, $\epsilon$, and $\eta$ are all trainable neural networks.


The GHD-based decoder offers two main advantages. First, it effectively exploits the gradient information of the energy function; our experiments show that it improves sampling performance on multimodal distributions. Second, by incorporating trainable correction terms and step sizes into classical Hamiltonian dynamics, it can achieve a good decoding density within only a few iterations. See C for the complete decoding procedure.
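The sketch below implements one such update (a leapfrog-style half-kick, drift, half-kick), assuming the networks act elementwise as in the equations above; all names are placeholders for the trainable components, not the authors' implementation.

```python
import torch

def ghd_step(y, v, l, grad_U, eps_l, eps0, Qv, Tv, Qy, Ty):
    """One GHD update. Qv, Tv, Qy, Ty stand for the trainable networks and
    eps_l = eps(l; phi) is the learned step size for iteration l; the energy
    gradient is re-evaluated after the position (drift) update."""
    g = grad_U(y)
    v = v - 0.5 * eps_l * (g * torch.exp(0.5 * eps0 * Qv(y, g, l)) + Tv(y, g, l))
    y = y + eps_l * (v * torch.exp(eps0 * Qy(v, l)) + Ty(v, l))
    g = grad_U(y)
    v = v - 0.5 * eps_l * (g * torch.exp(0.5 * eps0 * Qv(y, g, l)) + Tv(y, g, l))
    return y, v
```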


3.2 Loss function


To optimize the parameters of the decoding and encoding processes, we can minimize the KL divergence between the joint distributions of $(x,z_{[\cdot]})$ given by the decoding and encoding processes. Owing to the data processing inequality, this divergence is also an upper bound on $D_{KL}(p_{D}(x)\,||\,\pi(x))$, just like the bound in (1). Following the derivation in B, the KL divergence can be expressed as

$$
D_{KL}\left(p_{D}(x,z_{[\cdot]})\,||\,p_{E}(x,z_{[\cdot]})\right)=\mathcal{L}(\theta,\phi)+\log Z, \tag{5}
$$

where

$$
\mathcal{L}(\theta,\phi) = \mathbb{E}_{p_{D}}\left[\log p_{D}(x|z_{0};\phi)+U(x)\right]
+\int_{0}^{T}\frac{g(t)^{2}}{2}\,\mathbb{E}_{p_{D}}\Big[\left\|s(z_{t},x,t;\theta)\right\|^{2}+2\,\nabla_{z_{t}}\cdot s(z_{t},x,t;\theta)
+\left\|\nabla_{z_{t}}\log p_{D}(z_{t})\right\|^{2}\Big]\,\mathrm{d}t. \tag{6}
$$


By applying importance sampling to the time integral together with the Hutchinson estimator, we obtain an equivalent expression for $\mathcal{L}(\theta,\phi)$ that can be estimated efficiently and without bias by Monte Carlo sampling from $p_{D}$, with no numerical SDE solving:

$$
\mathcal{L}(\theta,\phi) = \mathbb{E}_{p_{D}}\left[\log p_{D}(x|z_{0};\phi)+U(x)\right]
+\mathbb{E}_{p(t)\,p(\epsilon)\,p_{D}(x,z_{t})}\left[\lambda(t)\,\mathcal{L}_{t}(x,z_{t},\epsilon;\theta)\right], \tag{7}
$$

where

$$
\mathcal{L}_{t}(x,z_{t},\epsilon;\theta)=\left\|s(z_{t},x,t;\theta)\right\|^{2}+2\,\frac{\partial\left[\epsilon^{\top}s(z_{t},x,t;\theta)\right]}{\partial z_{t}}\,\epsilon+\left\|\nabla_{z_{t}}\log p_{D}(z_{t})\right\|^{2},
$$


$p(\epsilon)$ is the Rademacher distribution with $\mathbb{E}[\epsilon]=0$, $\mathrm{cov}(\epsilon)=I$, $p(t)$ is a proposal distribution over $t\in[0,T]$, and the weighting function is $\lambda(t)=\frac{g(t)^{2}}{2p(t)}$. All neural networks involved in EDG can then be trained by minimizing the loss function (7) with stochastic gradient descent.
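As an illustration, the divergence term $2\,\nabla_{z_{t}}\cdot s$ in $\mathcal{L}_{t}$ can be estimated with a Rademacher probe as in the following sketch (our code, not the authors' implementation):

```python
import torch

def hutchinson_term(s_fn, z, x, t):
    """One-sample Rademacher estimate of 2 * div_z s(z, x, t; theta), using
    E_eps[eps^T (ds/dz) eps] = div_z s."""
    z = z.requires_grad_(True)
    eps = (torch.rand_like(z) < 0.5).to(z.dtype) * 2 - 1  # Rademacher probe
    s = s_fn(z, x, t)
    # grad of eps^T s w.r.t. z gives J^T eps; contracting with eps again
    # yields eps^T J eps, an unbiased estimate of the divergence.
    (jvp,) = torch.autograd.grad((s * eps).sum(), z, create_graph=True)
    return 2.0 * (jvp * eps).sum(-1)
```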


In our experiments, to reduce the variance inherent in the Monte Carlo approximation of the loss function, we adopt the strategy of [49] and dynamically adjust the time proposal distribution $p(t)$ based on a history buffer that stores recent values of the loss term on the r.h.s. of (7) associated with $t$. See D for more details.


3.3 Sample reweighting


After training, we can generate samples with the decoder $p_{D}(x|z_{0})$ and compute various statistics of the target distribution $\pi(x)$. For example, for a quantity of interest $O:\mathbb{R}^{d}\to\mathbb{R}$, we can draw $N$ augmented samples $\{(x^{n},z_{0}^{n})\}_{n=1}^{N}$ from $p_{D}(z_{0})\,p_{D}(x|z_{0})$ and estimate the expectation $\mathbb{E}_{\pi(x)}[O(x)]$ as

$$
\mathbb{E}_{\pi(x)}[O(x)]\approx\frac{1}{N}\sum_{n}O(x^{n}).
$$


However, this estimate may be systematically biased due to model error. To address this, we can use importance sampling with $p_{D}(x,z_{0})=p_{D}(z_{0})\,p_{D}(x|z_{0})$ as the proposal distribution and $p_{E}(x,z_{0})=\pi(x)\,p_{E}(z_{0}|x)$ as the augmented target distribution. We then assign to each decoder-generated sample $(x,z_{0})$ an unnormalized weight:

$$
w(x,z_{0})=\frac{\exp(-U(x))\,p_{E}(z_{0}|x)}{p_{D}(z_{0})\,p_{D}(x|z_{0})}\propto\frac{\pi(x)\,p_{E}(z_{0}|x)}{p_{D}(z_{0})\,p_{D}(x|z_{0})}, \tag{8}
$$


and obtain a consistent estimate of $\mathbb{E}_{\pi(x)}[O(x)]$:

$$
\mathbb{E}_{\pi(x)}[O(x)]\approx\frac{\sum_{n}w(x^{n},z_{0}^{n})\,O(x^{n})}{\sum_{n}w(x^{n},z_{0}^{n})},
$$


where the estimation error tends to zero as $N\to\infty$.
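A sketch of this self-normalized estimator follows; working in log space is an implementation choice of ours that avoids overflow in $\exp(-U(x))$.

```python
import torch

def reweighted_expectation(O_vals, log_w):
    """Self-normalized importance-sampling estimate of E_pi[O(x)].
    log_w[n] = log w(x^n, z0^n) from (8); softmax over the batch produces
    the normalized weights w_n / sum_m w_m in a numerically stable way."""
    weights = torch.softmax(log_w, dim=0)
    return (weights * O_vals).sum()
```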


Moreover, the weight function $w$ can also be used to estimate the normalization constant $Z$, a key task in many applications such as Bayesian model selection in statistics and free energy estimation in statistical physics. From (1), we obtain

$$
\log Z\geq\mathbb{E}_{p_{D}(x,z_{0})}\left[\log w(x,z_{0})\right].
$$


Here, the lower bound can also be estimated using samples from the decoder, and the bound becomes tight when $p_{D}(x,z_{0})=p_{E}(x,z_{0})$.


The main difficulty in the above computation lies in the intractability of the conditional density $p_{E}(z_{0}|x)$ when evaluating the weight function. To overcome this difficulty, following the conclusions of Section 4.3 in [37], we construct the following probability flow ordinary differential equation (ODE):

$$
\mathrm{d}z_{t}=\left(f(z_{t},t)-\frac{1}{2}g(t)^{2}\,s(z_{t},x,t;\theta)\right)\mathrm{d}t,
$$


with boundary condition $z_{T}\sim p_{E}(z_{T})$. If $s(z_{t},x,t;\theta)$ accurately approximates the score function $\nabla_{z_{t}}\log p_{E}(z_{t}|x)$ after training, the conditional distribution of $z_{t}|x$ given by the ODE matches that of the encoding process for every $t\in[0,T]$. Therefore, we can compute $p_{E}(z_{0}|x)$ efficiently using neural ODE methods [50].
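A rough sketch of this density computation via the instantaneous change-of-variables formula is given below; it uses plain Euler steps and an exact divergence (affordable in low dimensions), whereas a practical implementation would use an adaptive solver as in [50]. All signatures here are assumptions.

```python
import torch

def log_pE_z0_given_x(z0, x, s_fn, f, g, log_p_T, n_steps=100, T=1.0):
    """Euler sketch of log p_E(z0|x) via the probability flow ODE:
    integrate dz/dt = f(z,t) - 0.5 g(t)^2 s(z,x,t) from t=0 to t=T and
    accumulate the divergence, using
    log p_0(z0) = log p_T(z_T) + int_0^T div_z v(z_t, t) dt."""
    z = z0.clone()
    logdet = torch.zeros(z0.shape[0])
    dt = T / n_steps
    for i in range(n_steps):
        t = torch.full((z.shape[0], 1), i * dt)
        z = z.detach().requires_grad_(True)
        v = f(z, t) - 0.5 * g(t) ** 2 * s_fn(z, x, t)
        div = torch.zeros(z.shape[0])
        for d in range(z.shape[1]):  # exact trace; O(D) backward passes
            div = div + torch.autograd.grad(v[:, d].sum(), z,
                                            retain_graph=True)[0][:, d]
        z = (z + dt * v).detach()
        logdet = logdet + dt * div.detach()
    return log_p_T(z) + logdet
```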


4 Experiments


We empirically evaluate EDG on a variety of energy functions. First, we present results obtained on a set of two-dimensional distributions. Next, we demonstrate the performance of EDG on Bayesian logistic regression. Finally, we apply EDG to an Ising model. All experimental details are provided in E. In addition, we conduct an ablation study to verify the effectiveness of each module in EDG; see F for more information.


To demonstrate the advantages of our model, we compare EDG with the following sampling methods:


  • 1. Vanilla Hamiltonian Monte Carlo [8], abbreviated V-HMC.

  • 2. L2HMC [13], a GHD-based MCMC method with a trainable proposal distribution model.

  • 3. Boltzmann Generator (BG) [18], a VI method that models the surrogate distribution with RealNVP [51].

  • 4. Neural renormalization group (NeuralRG) [17], a method similar to BG designed specifically for Ising models. In this section, NeuralRG is used only in the Ising model experiments.

  • 5. Path Integral Sampler (PIS) [39], a diffusion-based sampling model that numerically simulates an SDE.

Table 1: Maximum mean discrepancy (MMD) between samples generated by each method and reference samples. See E for details of the discrepancy computation.

|       | MoG2 | MoG2(i) | MoG6 | MoG9 | Ring | Ring5 |
|-------|------|---------|------|------|------|-------|
| V-HMC | **0.01** | 1.56 | 0.02 | 0.04 | **0.01** | **0.01** |
| L2HMC | 0.04 | 0.94 | **0.01** | 0.03 | 0.02 | **0.01** |
| BG    | 1.90 | 1.63 | 2.64 | 0.07 | 0.05 | 0.18 |
| PIS   | **0.01** | 1.66 | **0.01** | 0.42 | **0.01** | 0.78 |
| EDG   | **0.01** | **0.50** | **0.01** | **0.02** | **0.01** | 0.02 |

[Figure 2 image grid; rows: Ref, V-HMC, L2HMC, BG, PIS, EDG; columns: MoG2, MoG2(i), MoG6, MoG9, Ring, Ring5.]

Figure 2: Density plots for the two-dimensional energy functions. See E for the generation of reference samples. We generate 500,000 samples for each method and plot the histograms.


Two-dimensional energy functions. We first compare our model with the others on several synthetic two-dimensional energy functions: MoG2 and MoG2(i) (mixtures of two isotropic Gaussians with centers a distance of 10 apart, with identical variances $\sigma^{2}=0.5$ for MoG2 and different variances $\sigma_{1}^{2}=1.5$, $\sigma_{2}^{2}=0.3$ for MoG2(i)), MoG6 (a mixture of six isotropic Gaussians with variance $\sigma^{2}=0.1$), MoG9 (a mixture of nine isotropic Gaussians with variance $\sigma^{2}=0.3$), Ring, and Ring5 (see [12] for the energy functions). Sample histograms for visual inspection are shown in Fig. 2, and sampling errors are summarized in Table 1. As shown, EDG delivers higher-quality samples than the other methods. To clarify the function of each component of EDG, we compare our model with a vanilla VAE in F.


Bayesian logistic regression. In the following experiments, we highlight the efficacy of EDG for Bayesian logistic regression, particularly when dealing with posterior distributions in high-dimensional spaces. Here we address a binary classification problem with labels $L=\{0,1\}$ and high-dimensional features $D$. The classifier output is defined as

$$
p(L=1|D,x)=\mathrm{softmax}(w^{\top}D+b),
$$


where $x=(w,b)$. Our goal is to draw samples $x^{1},\ldots,x^{N}$ from the posterior distribution

$$
\pi(x)\propto p(x)\prod_{(L,D)\in\mathcal{D}_{\mathrm{train}}}p(L|D,x)
$$


based on a training set $\mathcal{D}_{\mathrm{train}}$, where the prior distribution $p(x)$ is standard Gaussian. Then, for a given $D$, the conditional distribution $p(L|D,\mathcal{D}_{\mathrm{train}})$ can be approximated by $\sum_{n}p(L|D,x^{n})$. We conduct experiments on three datasets [52], evaluating the accuracy (ACC) and area under the curve (AUC) on the test subsets. Notably, as shown in Table 2, EDG consistently achieves the highest accuracy and AUC.


Table 2: Classification accuracy and AUC for the Bayesian logistic regression tasks. The experiments use consistent training and test data partitions, with the HMC step size set to 0.01. Mean accuracy and AUC values with their standard deviations are computed over 32 independent experiments for all datasets.

|       | AU Acc | AU AUC | GE Acc | GE AUC | HE Acc | HE AUC |
|-------|--------|--------|--------|--------|--------|--------|
| V-HMC | 82.97±1.94 | 90.88±0.83 | 78.52±0.48 | 77.67±0.28 | 86.75±1.63 | 93.35±0.76 |
| L2HMC | 73.26±1.56 | 79.69±3.65 | 62.02±4.19 | 60.23±5.10 | 82.23±2.81 | 90.48±0.51 |
| BG    | 82.99±1.18 | 91.23±0.67 | 78.14±1.44 | 77.59±0.73 | 86.75±1.99 | 93.44±0.39 |
| PIS   | 81.64±2.63 | 91.23±0.67 | 71.90±3.17 | 71.67±4.52 | 83.24±3.95 | 91.68±2.78 |
| EDG   | **84.96±1.67** | **92.82±0.69** | **79.40±1.74** | **82.79±1.46** | **88.02±3.90** | **95.10±1.23** |


We extend our analysis to the binary Covertype dataset, consisting of 581,012 data points and 54 features. The posterior of the classifier parameters follows a hierarchical Bayesian model (see Section 5 of [27]), where $x$ denotes the combination of the classifier parameters and the hyperparameter of the hierarchical Bayesian model. For computational efficiency, in BG and EDG, $\log\pi(x)$ is approximated without bias during training by

$$
\log\pi(x)\approx\log p(x)+\frac{|\mathcal{D}_{\mathrm{train}}|}{|\mathcal{B}|}\sum_{(L,D)\in\mathcal{B}}\log p(L|D,x),
$$


where $\mathcal{B}$ is a random mini-batch. For V-HMC and L2HMC, the exact posterior density is computed. The results in Table 3 show that EDG consistently outperforms the other methods.
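A minimal sketch of this unbiased mini-batch estimate (all function names are placeholders for the model-specific densities):

```python
def minibatch_log_posterior(x, batch, n_train, log_prior, log_lik):
    """Unbiased estimate of log pi(x): the mini-batch log-likelihood sum is
    rescaled by |D_train| / |B|. `batch` holds the labels L and features D
    of a random mini-batch of size |B|."""
    L, D = batch
    scale = n_train / L.shape[0]
    return log_prior(x) + scale * log_lik(L, D, x).sum()
```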


Table 3: Classification accuracy on the Covertype test dataset. Reported values are the mean accuracy and standard deviation over 32 independent experiments.

|     | V-HMC | L2HMC | BG | PIS | EDG |
|-----|-------|-------|----|-----|-----|
| Acc | 49.88±3.32 | 51.51±3.46 | 50.75±3.78 | 50.59±2.94 | **70.13±2.13** |

Table 4: Estimates of $\log Z_{\mathrm{Ising}}$ for the two-dimensional Ising model of dimension 256 ($16\times 16$), obtained by the method described in Section 3.3. We use a batch size of $n=256$ to estimate the mean and, applying the central limit theorem, compute the standard deviation of the sample mean as $\mathrm{std}/\sqrt{n}$.

| $\log Z_{\mathrm{Ising}}$ | NeuralRG | PIS | EDG |
|-------|----------|-----|-----|
| T=2.0 | 260±0.13 | 210±0.43 | **270±0.18** |
| T=2.1 | 250±0.14 | 208±0.41 | **255±0.19** |
| T=2.2 | 239±0.16 | 210±0.39 | **252±0.17** |
| T=2.3 | 231±0.15 | 214±0.37 | **233±0.17** |
| T=2.4 | **225±0.17** | 212±0.37 | **225±0.15** |
| T=2.5 | 219±0.17 | 202±0.37 | **221±0.14** |
| T=2.6 | **216±0.18** | 181±0.40 | 214±0.14 |
| T=2.7 | **212±0.18** | 189±0.36 | **212±0.14** |


Ising model. Finally, we validate the performance of EDG on the two-dimensional Ising model [17], a mathematical model of ferromagnetism in statistical mechanics. To ensure continuity of the physical variables, we adopt the continuous relaxation trick [53], which transforms the discrete variables into continuous auxiliary variables with target distribution

$$
\pi(\boldsymbol{x})=\exp\left(-\frac{1}{2}\boldsymbol{x}^{T}\left(K(T)+\alpha I\right)^{-1}\boldsymbol{x}\right)\times\prod_{i=1}^{N}\cosh(x_{i}),
$$

where $K$ is an $N\times N$ symmetric matrix depending on the temperature $T$, and $\alpha$ is a constant ensuring that $K+\alpha I$ is positive definite. For the corresponding discrete Ising variables $\boldsymbol{s}\in\{1,-1\}^{\otimes N}$, discrete samples can be obtained directly according to $\pi(\boldsymbol{s}|\boldsymbol{x})=\prod_{i}\left(1+e^{-2s_{i}x_{i}}\right)^{-1}$. In the absence of an external magnetic field, with each spin interacting only with its neighbors, $K$ is defined by $\sum_{<ij>}s_{i}s_{j}/T$, where $<ij>$ denotes the sum over nearest neighbors. The normalization constant of the continuously relaxed system is then given by $\log Z=\log Z_{\text{Ising}}+\frac{1}{2}\ln\det(K+\alpha I)-\frac{N}{2}[\ln(2/\pi)-\alpha]$ [17]. Furthermore, using the method described in Section 3.3, we provide lower-bound estimates of $\log Z_{\text{Ising}}$ at different temperatures for the samples generated by NeuralRG, PIS, and EDG. Since these are lower-bound estimates, larger values indicate more accurate results. As shown in Table 4, EDG provides the most accurate estimates of $\log Z$ over most of the temperature range. Fig. 3 shows the states generated at different temperatures.
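For illustration, discrete spins can be drawn from the relaxed samples sitewise as in the following sketch, since $\pi(s_{i}=+1|x_{i})=\left(1+e^{-2x_{i}}\right)^{-1}=\mathrm{sigmoid}(2x_{i})$:

```python
import torch

def sample_spins(x: torch.Tensor) -> torch.Tensor:
    """Draw s in {-1, +1}^N sitewise from pi(s|x), using
    pi(s_i = +1 | x_i) = sigmoid(2 x_i)."""
    p_plus = torch.sigmoid(2.0 * x)
    return torch.where(torch.rand_like(x) < p_plus,
                       torch.ones_like(x), -torch.ones_like(x))
```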

Figure 3: States generated by EDG for the Ising model of dimension 256 ($16\times 16$) at temperatures ranging from $T=2.0$ to $T=2.7$, where the latent variable $z_{0}$ is kept fixed. As the temperature increases, the generated states become increasingly disordered.

5 Conclusion


In summary, our work combines principles of VI- and diffusion-based methods and introduces EDG as an innovative and effective sampling method. Drawing inspiration from VAEs, EDG excels at efficiently generating samples from intricate Boltzmann distributions. Leveraging the expressive power of diffusion models, our method accurately estimates the KL divergence without numerically solving ordinary or stochastic differential equations. Empirical experiments validate the superior sampling performance of EDG.


Given its strong generative capacity and essentially unrestricted network design, there remains room for further exploration. Specific network architectures could be designed for different tasks to identify the most refined structure. Nonetheless, this is a first attempt to design a one-shot generator with the aid of diffusion models.


In the future, our primary focus will be on extending the application of EDG to build generative models for large-scale physical and chemical systems, such as proteins [54, 55].

Acknowledgments


The first and third authors are supported by the National Natural Science Foundation of China (Grant No. 12171367). The second author is supported by the National Natural Science Foundation of China (Grant Nos. 92270115, 12071301), the Science and Technology Commission of Shanghai Municipality (Grant No. 20JC1412500), and the Henan Academy of Sciences. The last author is supported by the National Natural Science Foundation of China (Grant No. 12288201), the Strategic Priority Research Program of the Chinese Academy of Sciences (Grant No. XDA25010404), the National Key R&D Program of China (2020YFA0712000), the Youth Innovation Promotion Association of CAS, and the Henan Academy of Sciences.


Appendix A Proof of the independence between $z_{T}$ and $x$


In the decoding process, if $z_{T}$ is independent of $z_{0}$, then we have

$$
p_{D}(z_{T}|x) = \int p_{D}(z_{T}|z_{0})\,p_{D}(z_{0}|x)\,\mathrm{d}z_{0}
= p_{D}(z_{T})\int p_{D}(z_{0}|x)\,\mathrm{d}z_{0}
= p_{D}(z_{T}).
$$

Appendix B Proof of (5)


For an infinitesimal lag time $\tau$, the Euler-Maruyama discretization of $z_{[\cdot]}$ in the decoding process gives

$$
z_{t} = z_{t-\tau}+f(z_{t-\tau},t-\tau)\,\tau+\sqrt{\tau}\,g(t-\tau)\,u_{t-\tau}, \tag{9}
$$
$$
z_{t-\tau} = z_{t}-\left(f(z_{t},t)-g(t)^{2}\,\nabla\log p_{D}(z_{t})\right)\tau+\sqrt{\tau}\,g(t)\,\bar{u}_{t}, \tag{10}
$$

where $u_{t-\tau},\bar{u}_{t}\sim\mathcal{N}(\cdot|0,I)$. Importantly, $u_{t-\tau}$ is independent of $\{z_{t^{\prime}}|t^{\prime}\leq t-\tau\}$ and $x$, while $\bar{u}_{t}$ is independent of $\{z_{t^{\prime}}|t^{\prime}\geq t\}$. The Euler-Maruyama approximation of $p_{E}(z_{[\cdot]}|x)$ gives

$$
z_{t-\tau}=z_{t}-\left(f(z_{t},t)-g(t)^{2}\,s(z_{t},x,t)\right)\tau+\sqrt{\tau}\,g(t)\,v_{t} \tag{11}
$$

with $v_{t}\sim\mathcal{N}(\cdot|0,I)$.


From (10) and (11), we obtain

$$
\mathbb{E}_{p_{D}}\left[\log p_{D}(z_{t-\tau}|z_{t})\right] = -\frac{d}{2}\log 2\pi-\frac{d}{2}\log\tau g(t)^{2}
-\frac{1}{2\tau g(t)^{2}}\,\mathbb{E}_{p_{D}}\left[\left\|z_{t-\tau}-z_{t}+\left(f(z_{t},t)-g(t)^{2}\nabla\log p_{D}(z_{t})\right)\tau\right\|^{2}\right] \tag{12}
$$