Energy-Based Diffusion Generator for Efficient Sampling of Boltzmann Distributions
Yan Wang^a, Ling Guo^b, Hao Wu^{*c}, Tao Zhou^d
Abstract
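Sampling from Boltzmann distributions, particularly those associated with high-dimensional and complex energy functions, is a significant challenge in many fields. In this work, we present the Energy-Based Diffusion Generator (EDG), a novel approach that integrates ideas from variational autoencoders and diffusion models. EDG uses a decoder to transform latent variables drawn from a simple distribution into samples that approximate the target Boltzmann distribution, while a diffusion-based encoder provides an accurate estimate of the Kullback-Leibler divergence during training. Notably, EDG is simulation-free: no ordinary or stochastic differential equations need to be solved during training. Furthermore, by removing constraints such as bijectivity from the decoder, EDG allows flexible network design. Through empirical evaluation, we demonstrate the superior performance of EDG on a variety of complex distribution tasks, where it outperforms existing methods.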
Keywords: Boltzmann distribution, energy-based model, generative model, diffusion model, variational autoencoder
^a School of Mathematical Sciences, Tongji University, Shanghai, China
^b Department of Mathematics, Shanghai Normal University, Shanghai, China
^c School of Mathematical Sciences, Institute of Natural Sciences and MOE-LSC, Shanghai Jiao Tong University, Shanghai, China
^d LSEC, Institute of Computational Mathematics and Scientific/Engineering Computing, AMSS, Chinese Academy of Sciences, Beijing, China
* Corresponding author: hwu81@sjtu.edu.cn
1 Introduction
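Sampling from Boltzmann distributions associated with high-dimensional, complex energy functions is a ubiquitous challenge in fields such as computational chemistry, statistical physics, and machine learning [1, 2]. Unlike the training of data-driven generative models, which can exploit pre-collected samples to learn complex distributions, sampling from a Boltzmann distribution poses unique and formidable challenges because no such data are readily available [3, 4]. For example, simulating phase transitions of the Ising model can be cast as sampling from a given energy function, a complex and difficult problem that has not yet been solved efficiently [5, 6].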
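Markov chain Monte Carlo (MCMC) methods [7], together with Brownian and Hamiltonian dynamics [8, 9, 10, 11], provide key tools for sampling from high-dimensional distributions. These methods iteratively generate and update candidate samples and are asymptotically unbiased in the limit of infinitely many sampling steps. In recent years, adaptive MCMC has been proposed as a strategy for generating candidate samples, leading to notable improvements in the efficiency and effectiveness of the sampling process [12, 13, 14]. Nevertheless, long mixing times still limit the performance of MCMC. Several studies have shown that constructing and optimizing proposal distributions with neural networks can substantially improve MCMC efficiency [13, 15, 16]. However, effective and broadly applicable loss functions for driving this optimization are still lacking.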
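Variational inference (VI) is another important approach to intractable distributions. VI approximates the target Boltzmann distribution with a generator that can produce samples quickly, and then optimizes the generator's parameters to minimize a statistical distance, such as the Kullback-Leibler (KL) divergence, between the distribution of generated samples and the target distribution. Because normalizing flows (NFs) can model complex distributions while providing explicit probability densities, they have been widely used as generators in VI methods [17, 18, 19, 20, 21, 22, 23, 24, 25, 26]. However, the bijectivity of NFs limits their effective capacity and often makes them inadequate for certain sampling tasks. Given the target density function and generated samples, the Stein discrepancy [27, 28] offers an alternative measure of goodness of fit, but the cost of computing kernel functions and their gradients limits its performance on high-dimensional tasks. In addition, combining MCMC with VI is an active research topic [29, 30, 31, 32, 33, 34]. Such combinations seek to exploit the strengths of both approaches and offer a promising route to addressing the challenges of high-dimensional sampling and to improving the efficiency of probabilistic modeling.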
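With the rapid development of diffusion-based generative models [35, 36, 37, 38], these models have also been applied to difficult sampling problems. By training time-dependent score-matching neural networks, the methods proposed in [39, 40, 41] transform a Gaussian distribution into a complex target density, using the KL divergence as the loss function. To mitigate mode-seeking behavior, [42] introduced the log-variance loss, which exhibits favorable properties. Moreover, [43] proposed an alternative training objective based on a flexible interpolation of the energy function, which substantially improves performance on multimodal targets. A common drawback of these methods, however, is their reliance on numerical differential equation solvers to compute time integrals, which can incur substantial computational cost.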
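In this work, we draw inspiration from variational autoencoders (VAEs) [44] and diffusion models to propose a new approach, the Energy-Based Diffusion Generator (EDG). The architecture of EDG closely resembles that of a VAE and consists of a decoder and an encoder. The decoder flexibly maps latent variables drawn from a tractable distribution to samples without imposing constraints such as bijectivity; in this work, we design a decoder based on generalized Hamiltonian dynamics to improve sampling efficiency. The encoder uses a diffusion process, which allows score-matching techniques to model the conditional distribution of the latent variables given a sample accurately and efficiently. Unlike existing diffusion-based methods, the EDG loss function admits unbiased estimates from random mini-batches, so no ordinary or stochastic differential equations need to be solved numerically during training. Numerical experiments demonstrate the effectiveness of EDG.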
2 Preliminaries and Setup
In this work, we address the task of constructing generative models that sample from the Boltzmann distribution driven by a prescribed energy function $U:\mathbb{R}^{d}\to\mathbb{R}$:
$$\pi(x)=\frac{1}{Z}\exp(-U(x)),$$
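where the normalizing constant $Z=\int\exp(-U(x))\,\mathrm{d}x$ is usually intractable. To address this challenge, Boltzmann generators [18] and their various extensions [24, 25, 26] have emerged as prominent techniques in recent years. These methods use NFs to parameterize trainable, analytically tractable density functions and optimize the parameters by minimizing the KL divergence between the surrogate density and $\pi$. However, unlike general-purpose generative models, the requirement of exact density evaluation imposes strong restrictions on NFs: every transformation layer must be bijective, and the determinant of its Jacobian matrix must be cheap to compute. These requirements inherently limit the ability of NFs to model complex distributions effectively.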
We now turn to a VAE-like generator, whose decoder produces samples according to
$$x|z_{0}\sim p_{D}(x|z_{0},\phi),$$
where $z_{0}\sim p_{D}(z_{0})$ is a latent variable drawn from a known prior distribution, typically the standard multivariate normal distribution. The decoder is parameterized by $\phi$, and we define $p_{D}(x|z_{0};\phi)$ as the Gaussian distribution $\mathcal{N}(x|\mu(z_{0};\phi),\Sigma(z_{0};\phi))$, where $\mu$ and $\Sigma$ are both parameterized by neural networks (NNs). As in a VAE, our goal is to train the networks $\mu$ and $\Sigma$ so that the marginal distribution $p_{D}(x)$ of the generated samples matches the target distribution.
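The training objective below needs two operations on this decoder: drawing $x|z_0$ and evaluating $\log p_D(x|z_0;\phi)$. The following is a minimal NumPy sketch under two assumptions that are not from the paper: $\Sigma$ is diagonal, and the networks $\mu$ and $\log\sigma$ are replaced by hypothetical callables `mu_fn` and `log_sigma_fn`.

```python
import numpy as np


def decoder_sample(z0, mu_fn, log_sigma_fn, rng):
    """Draw x ~ N(mu(z0), diag(sigma(z0)^2)) with the reparameterization trick."""
    mu = mu_fn(z0)
    sigma = np.exp(log_sigma_fn(z0))
    return mu + sigma * rng.standard_normal(mu.shape)


def decoder_log_prob(x, z0, mu_fn, log_sigma_fn):
    """log N(x | mu(z0), diag(sigma(z0)^2)), summed over the last axis."""
    mu = mu_fn(z0)
    log_sigma = log_sigma_fn(z0)
    sq = ((x - mu) / np.exp(log_sigma)) ** 2
    return -0.5 * np.sum(sq + 2.0 * log_sigma + np.log(2.0 * np.pi), axis=-1)


# Hypothetical stand-ins for the networks mu(z0; phi) and log sigma(z0; phi).
W = np.array([[2.0, 0.0], [0.5, 1.0]])
mu_fn = lambda z: z @ W.T
log_sigma_fn = lambda z: np.full(z.shape, -1.0)

rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 2))          # z0 ~ p_D(z0) = N(0, I)
x = decoder_sample(z0, mu_fn, log_sigma_fn, rng)
log_p = decoder_log_prob(x, z0, mu_fn, log_sigma_fn)
```

Sampling through `mu + sigma * noise` keeps $x$ differentiable with respect to the decoder parameters, which is what allows the expectations over $p_D$ in the loss to be optimized by stochastic gradients.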
Note that, unlike in a conventional data-driven VAE, we have no access to samples from the target distribution $\pi(x)$; indeed, obtaining such samples is precisely the purpose of the generator. Consequently, the variational approximation of the KL divergence $D_{KL}\left(\pi(x)\|p_{D}(x)\right)$ cannot be used to train the model. Instead, in this work we consider the reverse divergence and its upper bound:
$$
\begin{aligned}
D_{KL}\left(p_{D}(x)\,\|\,\pi(x)\right) &\leq D_{KL}\left(p_{D}(z_{0})\,p_{D}(x|z_{0},\phi)\,\|\,\pi(x)\,p_{E}(z_{0}|x,\theta)\right) \\
&= \mathbb{E}_{p_{D}(z_{0})\,p_{D}(x|z_{0},\phi)}\left[\log\frac{p_{D}(z_{0})\,p_{D}(x|z_{0},\phi)}{p_{E}(z_{0}|x,\theta)}+U(x)\right]+\log Z.
\end{aligned} \tag{1}
$$
Here, the parametric distribution $p_{E}(z_{0}|x,\theta)$ defines an encoder that maps $x$ to the latent variable $z_{0}$, and equality holds if $p_{E}(z_{0}|x,\theta)$ matches the conditional distribution of $z_{0}$ given $x$ induced by the decoder. Since $\log Z$ is a constant, minimizing the expectation on the right-hand side requires only the unnormalized energy $U$.
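To make the role of the encoder in (1) concrete, consider a hypothetical one-dimensional linear-Gaussian decoder (an illustration, not a model from the paper): $z_0\sim\mathcal{N}(0,1)$, $x|z_0\sim\mathcal{N}(a z_0,\sigma^2)$, and an energy $U(x)=x^2/(2(a^2+\sigma^2))$ chosen so that $p_D(x)=\pi(x)$. If the encoder is the exact posterior $p_D(z_0|x)$, the bound is tight and each Monte Carlo term of the expectation in (1) equals $-\log Z$ with $Z=\sqrt{2\pi(a^2+\sigma^2)}$. The sketch checks this numerically.

```python
import numpy as np


def log_normal(v, mean, var):
    """Log-density of the univariate normal N(mean, var) evaluated at v."""
    return -0.5 * ((v - mean) ** 2 / var + np.log(2.0 * np.pi * var))


a, sig = 2.0, 1.0                       # toy decoder x | z0 ~ N(a z0, sig^2)
var_x = a**2 + sig**2                   # marginal variance of x under the decoder
U = lambda x: x**2 / (2.0 * var_x)      # energy chosen so that pi(x) = p_D(x)

rng = np.random.default_rng(1)
z0 = rng.standard_normal(10_000)
x = a * z0 + sig * rng.standard_normal(z0.shape)

# Exact decoder posterior p_D(z0 | x), used here as the encoder p_E(z0 | x).
post_mean = a * x / var_x
post_var = sig**2 / var_x

integrand = (log_normal(z0, 0.0, 1.0) + log_normal(x, a * z0, sig**2)
             - log_normal(z0, post_mean, post_var) + U(x))
bound_minus_logZ = integrand.mean()     # Monte Carlo estimate of the expectation in (1)
log_Z = 0.5 * np.log(2.0 * np.pi * var_x)
```

Any other encoder makes the expectation strictly larger than $-\log Z$ (by the KL gap), which is why minimizing it over both decoder and encoder is a sensible training objective.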
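At first sight, this only seems to make the problem harder, since we still need to approximate a conditional distribution. In the following sections, however, we show how diffusion models [35, 37] can be used to construct the encoder efficiently and to optimize all parameters without numerically solving ODEs or SDEs.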
3 Energy-Based Diffusion Generator
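Diffusion models [37, 38] have recently emerged as an efficient approach to estimating data distributions. Their core idea is to construct a diffusion process that gradually transforms data into simple white noise and then to learn the reverse process that recovers the data distribution from the noise. In this work, we apply the principles of diffusion models by introducing a diffusion process in the latent space, which allows us to overcome the difficulties of the variational framework for the sampling problem defined in Eq. (1). We call the resulting model the Energy-Based Diffusion Generator (EDG).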
3.1 Model architecture
In the EDG framework, we start a diffusion process from the latent variable $z_{0}$ and combine it with the decoder to form what we call the "decoding process":
$$
\begin{aligned}
& z_{0}\in\mathbb{R}^{D}\sim p_{D}(z_{0})\triangleq\mathcal{N}(z_{0}|0,I), \quad x|z_{0}\sim p_{D}(x|z_{0};\phi), \\
& \mathrm{d}z_{t}=f(z_{t},t)\,\mathrm{d}t+g(t)\,\mathrm{d}W_{t}, \quad t\in[0,T],
\end{aligned} \tag{2}
$$
where $W_{t}$ is a standard Wiener process, $f(\cdot,t):\mathbb{R}^{D}\to\mathbb{R}^{D}$ is the drift coefficient, and $g(\cdot):\mathbb{R}\to\mathbb{R}$ is the diffusion coefficient. To simplify notation, we denote the probability distribution defined by the decoding process by $p_{D}$. The SDEs typically used in diffusion models satisfy two key conditions: (a) the transition density $p_{D}(z_{t}|z_{0})$ can be computed analytically, without numerically solving the Fokker-Planck equation; and (b) $z_{T}$ is approximately uninformative, in the sense that $p_{D}(z_{T})\approx p_{D}(z_{T}|z_{0})$.
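The text does not fix $f$ and $g$ at this point. As one common choice satisfying (a) and (b), assumed here for illustration, take the variance-preserving SDE with a constant rate $\beta$: $f(z,t)=-\tfrac{\beta}{2}z$ and $g(t)=\sqrt{\beta}$. Its transition density is $p_D(z_t|z_0)=\mathcal{N}(\alpha_t z_0,(1-\alpha_t^2)I)$ with $\alpha_t=e^{-\beta t/2}$, so $z_t$ can be drawn in one shot, and $\alpha_T\approx 0$ when $\beta T$ is large. The sketch compares this closed form with a brute-force Euler-Maruyama simulation, which is only there as a reference.

```python
import numpy as np

BETA = 2.0  # hypothetical constant noise rate


def alpha(t, beta=BETA):
    """Mean-decay factor alpha_t = exp(-beta t / 2) of the VP SDE."""
    return np.exp(-0.5 * beta * t)


def sample_zt_given_z0(z0, t, rng, beta=BETA):
    """Exact draw from p_D(z_t | z0) = N(alpha_t z0, (1 - alpha_t^2) I); no solver needed."""
    a_t = alpha(t, beta)
    return a_t * z0 + np.sqrt(1.0 - a_t**2) * rng.standard_normal(z0.shape)


def euler_maruyama(z0, t, n_steps, rng, beta=BETA):
    """Reference simulation of dz = -beta/2 z dt + sqrt(beta) dW (comparison only)."""
    dt = t / n_steps
    z = z0.copy()
    for _ in range(n_steps):
        z = z - 0.5 * beta * z * dt + np.sqrt(beta * dt) * rng.standard_normal(z.shape)
    return z


rng = np.random.default_rng(2)
z0 = np.full(50_000, 1.5)               # condition on a fixed starting point z0
t = 0.7
exact = sample_zt_given_z0(z0, t, rng)
simulated = euler_maruyama(z0, t, 200, rng)
```

With $p_D(z_0)=\mathcal{N}(0,I)$ this choice also keeps every marginal $p_D(z_t)$ equal to $\mathcal{N}(0,I)$, so $\nabla_{z_t}\log p_D(z_t)=-z_t$ in closed form.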
If we consider only the statistical properties of the latent diffusion process $z_{[\cdot]}=\{z_{t}\}_{t\in[0,T]}$, it is uninformative: it merely describes a transition from one simple noise distribution to another. However, once we consider the conditional distribution of $z_{t}$ given a sample $x$, the process $z_{[\cdot]}$ describes a gradual transformation of the complex conditional distribution $p_{D}(z_{0}|x)\propto p_{D}(z_{0})\,p_{D}(x|z_{0})$ into the tractable distribution $p_{D}(z_{T}|x)=p_{D}(z_{T})$, where the independence of $z_{T}$ and $x$ follows from the independence of $z_{0}$ and $z_{T}$ (see Appendix A). This means that, starting from $z_{T}\sim p_{D}(z_{T})$, we can obtain samples from $p_{D}(z_{0}|x)$ by simulating the following reverse-time diffusion equation [45]:
$$\mathrm{d}z_{\tilde{t}}=-\left(f(z_{\tilde{t}},\tilde{t})-g(\tilde{t})^{2}\,\nabla_{z_{\tilde{t}}}\log p_{D}(z_{\tilde{t}}|x)\right)\mathrm{d}\tilde{t}+g(\tilde{t})\,\mathrm{d}W_{\tilde{t}}, \tag{3}$$
where $\tilde{t}=T-t$ denotes reverse time. As in conventional diffusion models, the score function $\nabla_{z_{\tilde{t}}}\log p_{D}(z_{\tilde{t}}|x)$ is intractable, which makes this simulation difficult to carry out in practice; we therefore also approximate the score function with a neural network, denoted $s(z_{\tilde{t}},x,\tilde{t};\theta)$. This approximation yields what we call the "encoding process", obtained by combining the parametric reverse-time diffusion process with the target distribution of $x$:
$$
\begin{aligned}
& x\sim\pi(x), \quad z_{T}\sim p_{E}(z_{T})\triangleq p_{D}(z_{T}), \\
& \mathrm{d}z_{\tilde{t}}=-\left(f(z_{\tilde{t}},\tilde{t})-g(\tilde{t})^{2}\,s(z_{\tilde{t}},x,\tilde{t};\theta)\right)\mathrm{d}\tilde{t}+g(\tilde{t})\,\mathrm{d}W_{\tilde{t}}, \quad \tilde{t}=T-t.
\end{aligned} \tag{4}
$$
To simplify notation, we denote the distribution defined by the encoding process by $p_{E}$ throughout the paper.
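The encoding process (4) can be simulated with a standard backward-in-time Euler-Maruyama scheme. The sketch below assumes, for illustration only, the variance-preserving SDE $f(z,t)=-\tfrac{\beta}{2}z$, $g(t)=\sqrt{\beta}$ and a hypothetical one-dimensional decoder $x|z_0\sim\mathcal{N}(a z_0,\sigma^2)$. In this toy case $z_0|x\sim\mathcal{N}(m,v)$ with $m=ax/(a^2+\sigma^2)$ and $v=\sigma^2/(a^2+\sigma^2)$, so $p_D(z_t|x)=\mathcal{N}(\alpha_t m,\alpha_t^2 v+1-\alpha_t^2)$ with $\alpha_t=e^{-\beta t/2}$ and the exact score is available. Plugging the exact score in place of $s(z,x,t;\theta)$ should recover samples of $p_D(z_0|x)$; in EDG the trained network would take the place of `score_fn`.

```python
import numpy as np


def reverse_sde_sample(score_fn, x, n_particles, T, n_steps, beta, rng):
    """Euler-Maruyama integration of the reverse-time SDE from t = T down to t = 0.

    Assumes the VP forward SDE f(z, t) = -beta * z / 2 and g(t) = sqrt(beta).
    """
    dt = T / n_steps
    z = rng.standard_normal(n_particles)                  # z_T ~ p_E(z_T) = N(0, 1)
    for k in range(n_steps):
        t = T - k * dt                                    # current forward time
        f_minus_g2_score = -0.5 * beta * z - beta * score_fn(z, x, t)
        z = z - f_minus_g2_score * dt + np.sqrt(beta * dt) * rng.standard_normal(n_particles)
    return z                                              # approximately ~ p_D(z_0 | x)


# Hypothetical toy: z0 ~ N(0, 1) and x | z0 ~ N(a z0, sig^2), so z0 | x ~ N(m, v).
a, sig, beta = 2.0, 1.0, 2.0
v = sig**2 / (a**2 + sig**2)


def exact_score(z, x, t):
    """grad_z log p_D(z_t = z | x) for the toy model; EDG uses s(z, x, t; theta) instead."""
    m = a * x / (a**2 + sig**2)
    a_t = np.exp(-0.5 * beta * t)
    var_t = a_t**2 * v + 1.0 - a_t**2
    return -(z - a_t * m) / var_t


rng = np.random.default_rng(3)
x_obs = 1.5
z0_samples = reverse_sde_sample(exact_score, x_obs, 20_000, T=5.0, n_steps=1000,
                                beta=beta, rng=rng)
```

Note that this simulation is only needed to draw from the encoder; the training loss in Section 3.2 is designed so that it never has to be run during optimization.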
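Figure 1 illustrates the decoding and encoding processes. We note that latent diffusion models, which have recently attracted considerable attention [46, 47, 48], are mainly used for generative modeling in data-driven settings. Their key idea is to use a pretrained encoder and decoder to obtain a latent space that both represents the data effectively and facilitates efficient sampling, and then to learn the distribution of the latent variables with a diffusion model. Our EDG model applies a similar idea to energy-based sampling. EDG differs from previous latent diffusion models in structure and algorithm in two main ways: first, in EDG the diffusion model itself serves as the encoder, so no separate encoder is needed; second, the decoder is trained jointly with the diffusion model through a unified loss function (see Section 3.2).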
Figure 1: Probabilistic graphical models of the decoding and encoding processes; the trainable parts of the model are shown in gray.
Below, we describe the construction of the EDG modules used in our experiments. In practice, more effective neural networks can be designed as needed.
3.1.1 Boundary-condition-guided score function model
Since the true score function satisfies the following boundary conditions at $t=0$ and $t=T$:
$$
\begin{aligned}
\nabla_{z_{0}}\log p_{D}(z_{0}|x) &= \nabla_{z_{0}}\left[\log p_{D}(z_{0}|x)+\log p_{D}(x)\right] \\
&= \nabla_{z_{0}}\log p_{D}(x,z_{0}) \\
&= \nabla_{z_{0}}\left[\log p_{D}(x|z_{0})+\log p_{D}(z_{0})\right]
\end{aligned}
$$
and
$$\nabla_{z_{T}}\log p_{D}(z_{T}|x)=\nabla_{z_{T}}\log p_{D}(z_{T}),$$
we propose to parameterize $s(z,x,t;\theta)$ as
$$
\begin{aligned}
s(z,x,t;\theta) ={}& \left(1-\frac{t}{T}\right)\nabla_{z_{0}}\left[\log p_{D}(x|z_{0}=z)+\log p_{D}(z_{0}=z)\right] \\
&+\frac{t}{T}\,\nabla_{z_{T}}\log p_{D}(z_{T}=z)+\frac{t}{T}\left(1-\frac{t}{T}\right)s^{\prime}(z,x,t;\theta),
\end{aligned}
$$
where $s^{\prime}(z,x,t;\theta)$ is the neural network to be trained. This formulation guarantees that the error of $s$ is zero at both $t=0$ and $t=T$, because the weight $\frac{t}{T}\left(1-\frac{t}{T}\right)$ of the trainable term vanishes at both endpoints.
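The interpolation above is a thin wrapper around an arbitrary network. The sketch assumes $p_D(z_0)=\mathcal{N}(0,I)$ as in (2) and, additionally, $p_D(z_T)=\mathcal{N}(0,I)$ (true for a variance-preserving SDE, an assumption here), so both prior scores equal $-z$. The decoder gradient comes from a hypothetical linear-Gaussian decoder $x|z_0\sim\mathcal{N}(Wz_0,\sigma^2I)$, and `s_prime` is a random stand-in for the untrained network $s'$. The check confirms that the output matches the exact boundary scores regardless of $s'$.

```python
import numpy as np


def boundary_guided_score(z, x, t, T, grad_log_lik, s_prime):
    """Score model s(z, x, t) that equals the exact score at t = 0 and t = T.

    Assumes p_D(z0) = p_D(zT) = N(0, I), so both prior scores are -z.
    grad_log_lik(z, x) returns grad_{z0} log p_D(x | z0 = z); s_prime is the trainable part.
    """
    w = t / T
    score_t0 = grad_log_lik(z, x) - z      # grad_{z0}[log p_D(x|z0) + log p_D(z0)]
    score_tT = -z                          # grad_{zT} log p_D(zT)
    return (1.0 - w) * score_t0 + w * score_tT + w * (1.0 - w) * s_prime(z, x, t)


# Hypothetical linear-Gaussian decoder x | z0 ~ N(W z0, sigma2 * I).
W = np.array([[1.0, 0.5], [0.0, 2.0], [1.0, -1.0]])
sigma2 = 0.25
grad_log_lik = lambda z, x: (x - z @ W.T) @ W / sigma2

rng = np.random.default_rng(4)
s_prime = lambda z, x, t: rng.standard_normal(z.shape)    # arbitrary "untrained" network
z = rng.standard_normal((5, 2))
x = rng.standard_normal((5, 3))
T = 1.0
```

Because the network only has to model the interior correction, its output scale matters less near the endpoints, where the encoder would otherwise be most sensitive to score errors.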
3.1.2 Decoder based on generalized Hamiltonian dynamics
Inspired by generalized Hamiltonian dynamics (GHD) [12, 13], the decoder generates its output $x$ through the following procedure. First, an initial sample and velocity $(y,v)$ are generated from the latent variable $z_{0}$. Then $(y,v)$ are updated iteratively as follows, where $l$ indexes the update step:
$$
\begin{aligned}
v &:= v-\frac{\epsilon(l;\phi)}{2}\left(\nabla U(y)\odot e^{\frac{\epsilon_{0}}{2}Q_{v}(y,\nabla U(y),l;\phi)}+T_{v}(y,\nabla U(y),l;\phi)\right), \\
y &:= y+\epsilon(l;\phi)\left(v\odot e^{\epsilon_{0}Q_{y}(v,l;\phi)}+T_{y}(v,l;\phi)\right), \\
v &:= v-\frac{\epsilon(l;\phi)}{2}\left(\nabla U(y)\odot e^{\frac{\epsilon_{0}}{2}Q_{v}(y,\nabla U(y),l;\phi)}+T_{v}(y,\nabla U(y),l;\phi)\right).
\end{aligned}
$$
Finally, the decoder output $x$ is given by
$$x=y-\exp(\epsilon_{0}\eta(y;\phi))\,\nabla U(y)+\sqrt{2\exp(\epsilon_{0}\eta(y;\phi))}\,\xi,$$
where $\xi$ follows the standard Gaussian distribution. This update can be interpreted as a single finite step, with step size $\exp(\epsilon_{0}\eta(y;\phi))$, of the Brownian dynamics $\mathrm{d}y=-\nabla U(y)\,\mathrm{d}t+\sqrt{2}\,\mathrm{d}W_{t}$, whose stationary distribution is $\pi$. In the above procedure, $Q_{v}$, $T_{v}$, $Q_{y}$, $T_{y}$, $\epsilon$, and $\eta$ are all trainable neural networks.
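One iteration of the $(y,v)$ update can be sketched as follows. The networks $\epsilon$, $Q_v$, $T_v$, $Q_y$, $T_y$ are replaced by hypothetical callables, and the value `eps0=0.1` is an assumption. With $Q_v=Q_y=T_v=T_y=0$ and a constant step size, the update reduces to a standard leapfrog step of Hamiltonian dynamics with $H(y,v)=U(y)+\tfrac12\|v\|^2$, which approximately conserves $H$; the check uses this reduction on a quadratic energy.

```python
import numpy as np


def ghd_step(y, v, l, grad_U, eps, Q_v, T_v, Q_y, T_y, eps0=0.1):
    """One generalized-Hamiltonian (leapfrog-type) update of (y, v) at step index l."""
    h = eps(l)
    g = grad_U(y)
    v = v - 0.5 * h * (g * np.exp(0.5 * eps0 * Q_v(y, g, l)) + T_v(y, g, l))
    y = y + h * (v * np.exp(eps0 * Q_y(v, l)) + T_y(v, l))
    g = grad_U(y)                                   # gradient at the updated position
    v = v - 0.5 * h * (g * np.exp(0.5 * eps0 * Q_v(y, g, l)) + T_v(y, g, l))
    return y, v


# Quadratic energy U(y) = |y|^2 / 2; all correction networks set to zero.
U = lambda y: 0.5 * np.sum(y**2)
grad_U = lambda y: y
zero2 = lambda a, b, l: np.zeros_like(a)            # stand-in for Q_v and T_v
zero1 = lambda a, l: np.zeros_like(a)               # stand-in for Q_y and T_y
eps = lambda l: 0.1                                 # constant step size

y, v = np.array([1.0, 0.5]), np.array([0.0, -0.3])
H0 = U(y) + 0.5 * np.sum(v**2)
for l in range(10):
    y, v = ghd_step(y, v, l, grad_U, eps, zero2, zero2, zero1, zero1)
H1 = U(y) + 0.5 * np.sum(v**2)
```

The trainable terms then act as learned, state-dependent rescalings ($Q$) and shifts ($T$) of this leapfrog step, which is what lets a few iterations move mass toward low-energy regions.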
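The GHD-based decoder has two main advantages. First, it makes effective use of the gradient of the energy function, which, as our experiments show, improves sampling performance on multimodal distributions. Second, by adding trainable correction terms and step sizes to classical Hamiltonian dynamics, it achieves a good decoding density within only a few iterations. See Appendix C for the complete decoding procedure.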
3.2 Loss function
To optimize the parameters of the decoding and encoding processes, we minimize the KL divergence between the joint distributions of $(x,z_{[\cdot]})$ defined by the decoding and encoding processes. By the data processing inequality, this divergence is also an upper bound on $D_{KL}(p_{D}(x)\|\pi(x))$, like the bound in (1). Following the derivation in Appendix B, the KL divergence can be written as
$$D_{KL}\left(p_{D}(x,z_{[\cdot]})\,\|\,p_{E}(x,z_{[\cdot]})\right)=\mathcal{L}(\theta,\phi)+\log Z, \tag{5}$$
where
$$
\begin{aligned}
\mathcal{L}(\theta,\phi) ={}& \mathbb{E}_{p_{D}}\left[\log p_{D}(x|z_{0};\phi)+U(x)\right] \\
&+\int_{0}^{T}\frac{g(t)^{2}}{2}\,\mathbb{E}_{p_{D}}\left[\left\|s(z_{t},x,t;\theta)\right\|^{2}+2\nabla_{z_{t}}\cdot s(z_{t},x,t;\theta)+\left\|\nabla_{z_{t}}\log p_{D}(z_{t})\right\|^{2}\right]\mathrm{d}t.
\end{aligned} \tag{6}
$$
By applying importance sampling to the time integral and using the Hutchinson trace estimator, we obtain an equivalent expression for $\mathcal{L}(\theta,\phi)$ that can be estimated efficiently and without bias by Monte Carlo sampling from $p_{D}$, without numerically solving any SDE:
$$
\begin{aligned}
\mathcal{L}(\theta,\phi) ={}& \mathbb{E}_{p_{D}}\left[\log p_{D}(x|z_{0};\phi)+U(x)\right] \\
&+\mathbb{E}_{p(t)\,p(\epsilon)\,p_{D}(x,z_{t})}\left[\lambda(t)\,\mathcal{L}_{t}(x,z_{t},\epsilon;\theta)\right],
\end{aligned} \tag{7}
$$
where
$$\mathcal{L}_{t}(x,z_{t},\epsilon;\theta)=\left\|s(z_{t},x,t;\theta)\right\|^{2}+2\,\frac{\partial\left[\epsilon^{\top}s(z_{t},x,t;\theta)\right]}{\partial z_{t}}\,\epsilon+\left\|\nabla_{z_{t}}\log p_{D}(z_{t})\right\|^{2},$$
$p(\epsilon)$ is the Rademacher distribution, with $\mathbb{E}[\epsilon]=0$ and $\mathrm{cov}(\epsilon)=I$. $p(t)$ is a proposal distribution over $t\in[0,T]$, and the weighting function is $\lambda(t)=\frac{g(t)^{2}}{2p(t)}$. All neural networks in EDG can then be trained by minimizing the loss function (7) with stochastic gradient descent.
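With Rademacher $\epsilon$, the middle term of $\mathcal{L}_t$ is $2\,\epsilon^{\top}J_s\epsilon$, where $J_s$ is the Jacobian of $s$ in $z_t$. This is a Hutchinson-type estimate of $2\,\nabla_{z_t}\cdot s$. The pure-Python sketch below evaluates single Monte Carlo samples of $\mathcal{L}_t$ for fixed $(x,t)$. It is illustrative only and rests on several assumptions:

- the score network is replaced by a hypothetical closed-form function `s(z)`;
- the product $\epsilon^{\top}J_s\epsilon$ is taken by central finite differences rather than automatic differentiation;
- the helper names are our own.

In the toy check, $s$ equals the score of a standard Gaussian $p_D(z_t)$. Integration by parts then gives $\mathbb{E}[\mathcal{L}_t]=\mathbb{E}\|s-\nabla\log p_D\|^2=0$.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def rademacher(d, rng):
    # i.i.d. +/-1 entries, so E[eps] = 0 and cov(eps) = I
    return [rng.choice((-1.0, 1.0)) for _ in range(d)]


def eps_jac_eps(s, z, eps, h=1e-4):
    # eps^T (ds/dz) eps via a central finite difference of s along direction eps
    zp = [zi + h * ei for zi, ei in zip(z, eps)]
    zm = [zi - h * ei for zi, ei in zip(z, eps)]
    jvp = [(a - b) / (2.0 * h) for a, b in zip(s(zp), s(zm))]
    return dot(eps, jvp)


def loss_t_sample(s, grad_log_pD, z, eps):
    # One sample of ||s||^2 + 2 eps^T J_s eps + ||grad log p_D(z)||^2
    sz, gz = s(z), grad_log_pD(z)
    return dot(sz, sz) + 2.0 * eps_jac_eps(s, z, eps) + dot(gz, gz)


def gaussian_score(z):
    # Hypothetical stand-in for s(z_t, x, t; theta): score of N(0, I), i.e. -z
    return [-zi for zi in z]


def toy_mean_loss(n=20000, d=2, seed=0):
    # Average of L_t over z ~ N(0, I) and Rademacher eps; should be close to 0
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = [rng.gauss(0.0, 1.0) for _ in range(d)]
        total += loss_t_sample(gaussian_score, gaussian_score, z, rademacher(d, rng))
    return total / n
```

For example, with `z = [1.0, 2.0]` and `eps = [1.0, -1.0]`, `loss_t_sample(gaussian_score, gaussian_score, z, eps)` gives $5-4+5=6$ up to finite-difference rounding. In a real implementation the finite difference would be replaced by an autodiff Jacobian-vector product.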
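In our experiments, we reduce the variance inherent in the Monte Carlo approximation of the loss by following the strategy of [49]. The time proposal distribution $p(t)$ is adjusted dynamically from a history buffer that stores the most recent values of the loss term associated with $t$ on the right-hand side of (7). See Appendix D for more details.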
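Appendix D specifies the exact scheme of [49], and it is not reproduced here. The sketch below shows one common form of such a loss-aware proposal, under our own assumptions:

- $[0,T]$ is split into equal bins, and each bin keeps its last `history` loss values;
- until every buffer is full, $p(t)$ is uniform;
- afterwards, $p(t)$ is piecewise constant, with each bin's mass proportional to the root-mean-square of its stored losses.

The class and its parameters are hypothetical. The sampler also returns the density $p(t)$ that enters $\lambda(t)=g(t)^2/(2p(t))$.

```python
import math
import random
from collections import deque


class LossAwareTimeSampler:
    """Hypothetical piecewise-constant proposal p(t) on [0, T] driven by recent losses."""

    def __init__(self, num_bins, T=1.0, history=10):
        self.num_bins, self.T, self.history = num_bins, T, history
        self.buffers = [deque(maxlen=history) for _ in range(num_bins)]

    def bin_probs(self):
        # Uniform until every bin holds a full history of loss values.
        if any(len(b) < self.history for b in self.buffers):
            return [1.0 / self.num_bins] * self.num_bins
        rms = [math.sqrt(sum(v * v for v in b) / len(b)) for b in self.buffers]
        total = sum(rms)
        if total == 0.0:
            return [1.0 / self.num_bins] * self.num_bins
        return [r / total for r in rms]

    def sample(self, rng):
        # Pick a bin, draw t uniformly inside it, and return (t, p(t)).
        probs = self.bin_probs()
        k = rng.choices(range(self.num_bins), weights=probs)[0]
        width = self.T / self.num_bins
        t = (k + rng.random()) * width
        return t, probs[k] / width

    def update(self, t, loss):
        # Record the loss observed at time t in the buffer of its bin.
        k = min(int(t / self.T * self.num_bins), self.num_bins - 1)
        self.buffers[k].append(loss)
```

In a training loop, one would:

1. draw `t, p_t = sampler.sample(rng)`;
2. weight the loss term by `g(t)**2 / (2 * p_t)`;
3. call `sampler.update(t, loss_value)`.

The RMS rule and the uniform warm-up are design choices of this sketch, not necessarily those of [49].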
3.3 Sample reweighting
After training, we can use the decoder $p_D(x|z_0)$ to generate samples and compute various statistics of the target distribution $\pi(x)$. For example, for a quantity of interest $O:\mathbb{R}^{d}\to\mathbb{R}$, we can draw $N$ augmented samples $\{(x^{n},z_{0}^{n})\}_{n=1}^{N}$ from $p_D(z_0)p_D(x|z_0)$ and estimate the expectation $\mathbb{E}_{\pi(x)}[O(x)]$ as
$$\mathbb{E}_{\pi(x)}[O(x)]\approx\frac{1}{N}\sum_{n}O(x^{n}).$$
However, because of model error, this estimate may be systematically biased. To correct for this, we can apply importance sampling. The proposal distribution is $p_D(x,z_0)=p_D(z_0)p_D(x|z_0)$, and the augmented target distribution is $p_E(x,z_0)=\pi(x)p_E(z_0|x)$. Each sample $(x,z_0)$ generated by the decoder is then assigned the unnormalized weight
$$w(x,z_0)=\frac{\exp(-U(x))\,p_E(z_0|x)}{p_D(z_0)\,p_D(x|z_0)}\propto\frac{\pi(x)\,p_E(z_0|x)}{p_D(z_0)\,p_D(x|z_0)}\tag{8}$$
which yields a consistent estimate of $\mathbb{E}_{\pi(x)}[O(x)]$:
$$\mathbb{E}_{\pi(x)}[O(x)]\approx\frac{\sum_{n}w(x^{n},z_{0}^{n})\,O(x^{n})}{\sum_{n}w(x^{n},z_{0}^{n})},$$
where the estimation error tends to zero as $N\to\infty$.
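In practice, the weights in (8) are best handled in log space, because $\exp(-U(x))$ and the densities can under- or overflow. The sketch below computes $\log w$ from its four log-density terms and then the self-normalized ratio above, subtracting the largest log-weight for numerical stability.

The toy check is hypothetical and has no latent variable:

- the target is $\pi=\mathcal{N}(0,1)$, via $U(x)=x^2/2$;
- a proposal $q=\mathcal{N}(0,2^2)$ plays the role of $p_D(x|z_0)$;
- the two $z_0$ terms are set to zero.

The estimate of $\mathbb{E}_\pi[x^2]=1$ should come out close to 1.

```python
import math
import random


def log_weight(neg_U, log_pE_z0_given_x, log_pD_z0, log_pD_x_given_z0):
    # log w(x, z0) = -U(x) + log p_E(z0|x) - log p_D(z0) - log p_D(x|z0), cf. (8)
    return neg_U + log_pE_z0_given_x - log_pD_z0 - log_pD_x_given_z0


def self_normalized_estimate(log_weights, values):
    # sum_n w_n O(x_n) / sum_n w_n, with the max log-weight subtracted for stability
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(wi * oi for wi, oi in zip(w, values)) / sum(w)


def toy_second_moment(n=100000, seed=0):
    # Toy, no latent variable: target N(0, 1) via U(x) = x^2 / 2, proposal N(0, 2^2)
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 2.0) for _ in range(n)]
    log_q = [-0.5 * math.log(2.0 * math.pi * 4.0) - x * x / 8.0 for x in xs]
    lw = [log_weight(-0.5 * x * x, 0.0, 0.0, lq) for x, lq in zip(xs, log_q)]
    return self_normalized_estimate(lw, [x * x for x in xs])
```

Because the weights only need to be known up to a constant, the unknown $Z$ cancels in the ratio. This is why $\exp(-U(x))$ can be used in place of $\pi(x)$.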
The weight function $w$ can also be used to estimate the normalizing constant $Z$. This is a key task in many applications, such as Bayesian model selection in statistics and free-energy estimation in statistical physics. From (1), we obtain

$$\log Z\geq\mathbb{E}_{p_{D}(x,z_{0})}\left[\log w(x,z_{0})\right],$$

since, by (8), $\mathbb{E}_{p_D(x,z_0)}[\log w(x,z_0)]=\log Z-\mathrm{KL}\left(p_D(x,z_0)\,\|\,p_E(x,z_0)\right)$ and the KL divergence is nonnegative.
The lower bound can likewise be estimated from decoder samples, and it is tight when $p_D(x,z_0)=p_E(x,z_0)$.
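For the weights in (8), $\mathbb{E}_{p_D}[\log w]=\log Z-\mathrm{KL}(p_D\,\|\,p_E)$, so the sample mean of $\log w$ sits below $\log Z$ by the KL gap. The sketch below illustrates this in a hypothetical one-dimensional toy with no latent variable: $U(x)=x^2/2$, so $Z=\sqrt{2\pi}$, and a mismatched proposal $q=\mathcal{N}(0,1.5^2)$.

For comparison, it also returns $\log\frac{1}{N}\sum_n w_n$. This is the usual importance-sampling estimate of $\log Z$, which is consistent and, by Jensen's inequality, also a lower bound in expectation. It is typically much tighter than the mean of $\log w$.

```python
import math
import random


def log_z_bounds(log_weights):
    # Returns (mean of log w, log of mean w) for samples drawn from the proposal
    n = len(log_weights)
    mean_log_w = sum(log_weights) / n
    m = max(log_weights)
    log_mean_w = m + math.log(sum(math.exp(lw - m) for lw in log_weights) / n)
    return mean_log_w, log_mean_w


def toy_log_weights(n=100000, sigma=1.5, seed=0):
    # Toy: U(x) = x^2 / 2 (so Z = sqrt(2 pi)), proposal N(0, sigma^2), no latent variable
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        x = rng.gauss(0.0, sigma)
        log_q = -0.5 * math.log(2.0 * math.pi * sigma ** 2) - x * x / (2.0 * sigma ** 2)
        out.append(-0.5 * x * x - log_q)
    return out
```

With `sigma=1.5`, the mean of $\log w$ should land near $-\sigma^2/2+\tfrac12\log(2\pi\sigma^2)+\tfrac12\approx 0.70$, below $\log Z=\tfrac12\log 2\pi\approx 0.92$. The log-mean-exp value should be close to $0.92$.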
The main difficulty in these computations is that the marginal encoder density $p_E(z_0|x)$, which the weight function requires, is intractable. To overcome this, we follow the results of Section 4.3 of [37] and construct the following probability flow ordinary differential equation (ODE):
$$\mathrm{d}z_{t}=\left(f(z_{t},t)-\frac{1}{2}g(t)^{2}s(z_{t},x,t;\theta)\right)\mathrm{d}t,$$
with boundary condition $z_T\sim p_E(z_T)$. Suppose that, after training, $s(z_t,x,t;\theta)$ accurately approximates the score function $\nabla_{z_t}\log p_E(z_t|x)$. Then the conditional distribution of $z_t|x$ given by the ODE matches that of the encoding process for every $t\in[0,T]$. Therefore, $p_E(z_0|x)$ can be computed efficiently with the neural ODE method [50].
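The neural ODE method uses the instantaneous change-of-variables formula. For an ODE $\mathrm{d}z_t=v(z_t,t)\,\mathrm{d}t$, it gives

$$\log p_0(z_0)=\log p_T(z_T)+\int_0^T\nabla_z\cdot v(z_t,t)\,\mathrm{d}t,$$

where the integral is taken along the trajectory started at $z_0$. The sketch below integrates $z_t$ and the divergence integral jointly with classical RK4. It uses a hypothetical 1-D toy rather than the EDG networks:

- the forward SDE has $f(z,t)=-\beta z/2$ and $g(t)=\sqrt{\beta}$;
- it starts from a Gaussian $p_E(z_0|x)=\mathcal{N}(m,s_0^2)$, so every marginal is Gaussian and the exact score is known.

So that the check is exact, the boundary density is the exact Gaussian marginal at $T$. In high dimension, the divergence would be replaced by a Hutchinson estimate.

```python
import math

# Hypothetical toy constants: noise rate, initial mean/std of p_E(z_0|x), horizon
BETA, M, S0, T = 1.0, 2.0, 0.5, 3.0


def marginal(t):
    # Mean and variance of z_t | x for the toy forward SDE
    decay = math.exp(-BETA * t)
    return M * math.sqrt(decay), S0 ** 2 * decay + (1.0 - decay)


def score(z, t):
    mean, var = marginal(t)
    return -(z - mean) / var


def velocity(z, t):
    # Probability-flow drift f(z, t) - 0.5 g(t)^2 s(z, t) with f = -BETA z / 2, g^2 = BETA
    return -0.5 * BETA * z - 0.5 * BETA * score(z, t)


def divergence(z, t):
    # d velocity / dz (exact in 1-D)
    _, var = marginal(t)
    return -0.5 * BETA + 0.5 * BETA / var


def log_normal(z, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (z - mean) ** 2 / var


def rhs(z, t):
    return velocity(z, t), divergence(z, t)


def log_p0_via_ode(z0, steps=1000):
    # RK4 on (z_t, int_0^t div v) from t = 0 to T, then change of variables
    h = T / steps
    z, acc, t = z0, 0.0, 0.0
    for _ in range(steps):
        k1 = rhs(z, t)
        k2 = rhs(z + 0.5 * h * k1[0], t + 0.5 * h)
        k3 = rhs(z + 0.5 * h * k2[0], t + 0.5 * h)
        k4 = rhs(z + h * k3[0], t + h)
        z += h / 6.0 * (k1[0] + 2.0 * k2[0] + 2.0 * k3[0] + k4[0])
        acc += h / 6.0 * (k1[1] + 2.0 * k2[1] + 2.0 * k3[1] + k4[1])
        t += h
    mean_T, var_T = marginal(T)
    return log_normal(z, mean_T, var_T) + acc
```

Here `log_p0_via_ode(z0)` should agree with the exact `log_normal(z0, M, S0**2)` to many digits. This is the property that lets a trained score network stand in for the intractable $p_E(z_0|x)$.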
4 Experiments
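We empirically evaluate EDG on a variety of energy functions. We first present results on a set of two-dimensional distributions. We then show the performance of EDG on Bayesian logistic regression, and finally we apply EDG to an Ising model. All experimental details are provided in Appendix E. We also conduct an ablation study to verify the effectiveness of each module in EDG; see Appendix F for more information.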
To demonstrate the advantages of our model, we compare EDG with the following sampling methods:
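1. Vanilla Hamiltonian Monte Carlo [8], denoted V-HMC.
2. L2HMC [13], a GHD-based MCMC method with a trainable model of the proposal distribution.
3. Boltzmann Generator (BG) [18], a VI method that models the surrogate distribution with RealNVP [51].
4. Neural Renormalization Group (NeuralRG) [17], a BG-like method designed specifically for the Ising model. In this section, NeuralRG is used only in the Ising model experiments.
5. Path Integral Sampler (PIS) [39], a diffusion-based sampling model that relies on numerical simulation of an SDE.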
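Table 1: Maximum mean discrepancy (MMD) between the samples generated by each generator and the reference samples. See Appendix E for details of the discrepancy computation.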
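Table 1 reports MMD values. The kernel and bandwidth actually used are described in Appendix E and are not reproduced here. As a generic illustration only, the sketch below computes the biased (V-statistic) estimate of the squared MMD between two sample sets. It uses a Gaussian kernel whose `bandwidth` is a hypothetical parameter. The cost is quadratic in the number of samples, so it suits small illustrative sets.

```python
import math
import random


def gaussian_kernel(a, b, bandwidth):
    sq = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq / (2.0 * bandwidth ** 2))


def mmd2_biased(xs, ys, bandwidth=1.0):
    # V-statistic: mean k(x, x') + mean k(y, y') - 2 mean k(x, y); nonnegative
    def mean_k(us, vs):
        total = sum(gaussian_kernel(u, v, bandwidth) for u in us for v in vs)
        return total / (len(us) * len(vs))

    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)
```

Two sample sets drawn from the same distribution give a value near zero. A set shifted by a couple of standard deviations gives a clearly larger value.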
Figure 2: Density plots for the two-dimensional energy functions. See Appendix E for the generation of the reference samples. For each method we generate 500,000 samples and plot their histogram.
Two-dimensional energy functions. We first compare our model with the other methods on several synthetic two-dimensional energy functions:

- MoG2(i): a mixture of two isotropic Gaussians, either with equal variance $\sigma^{2}=0.5$ or with different variances $\sigma_{1}^{2}=1.5$ and $\sigma_{2}^{2}=0.3$, whose centers are a distance 10 apart;
- MoG6: a mixture of six isotropic Gaussians with variance $\sigma^{2}=0.1$;
- MoG9: a mixture of nine isotropic Gaussians with variance $\sigma^{2}=0.3$;
- Ring and Ring5, whose energy functions are given in [12].

Figure 2 shows sample histograms for visual inspection, and Table 1 summarizes the sampling errors. As shown, EDG produces higher-quality samples than the other methods. To clarify the role of each component of EDG, we compare our model with a vanilla VAE in Appendix F.
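The mixture targets can be written as $U(x)=-\log\sum_k w_k\,\mathcal{N}(x;\mu_k,\sigma_k^2 I)$, which the sketch below evaluates with a stable log-sum-exp. It assumes equal mixture weights. For MoG2 it places the centers at $(\pm 5,0)$; any two centers a distance 10 apart give the same picture up to a rigid motion. These placements are our own hypothetical choices. The MoG6/MoG9 layouts and the Ring energies of [12] are not reproduced.

```python
import math


def mog_energy(x, centers, variances, weights=None):
    # U(x) = -log sum_k w_k N(x; mu_k, var_k I), evaluated with a stable log-sum-exp
    if weights is None:
        weights = [1.0 / len(centers)] * len(centers)
    d = len(x)
    logs = []
    for mu, var, w in zip(centers, variances, weights):
        sq = sum((xi - mi) ** 2 for xi, mi in zip(x, mu))
        logs.append(math.log(w) - 0.5 * d * math.log(2.0 * math.pi * var) - 0.5 * sq / var)
    m = max(logs)
    return -(m + math.log(sum(math.exp(v - m) for v in logs)))


# Hypothetical MoG2 layout: two centers at distance 10 on the horizontal axis
MOG2_CENTERS = [(-5.0, 0.0), (5.0, 0.0)]
```

With equal variances $0.5$, the energy at either center is essentially $\log(2\pi)$, and the midpoint $(0,0)$ is a high-energy barrier. With variances $(1.5, 0.3)$, the narrower component has the lower energy at its center.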
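Bayesian logistic regression. In the next experiments, we demonstrate the efficacy of EDG for Bayesian logistic regression, in particular when the posterior distribution lies in a high-dimensional space. We consider a binary classification problem with labels $L\in\{0,1\}$ and high-dimensional features $D$. The output of the classifier is defined as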
$$p(L=1|D,x)=\mathrm{sigmoid}(w^{\top}D+b),$$
where $x=(w,b)$. Our goal is to draw samples $x^{1},\ldots,x^{N}$ from the posterior distribution
$$\pi(x)\propto p(x)\prod_{(L,D)\in\mathcal{D}_{\mathrm{train}}}p(L|D,x)$$
given the training set $\mathcal{D}_{\mathrm{train}}$, where the prior $p(x)$ is a standard Gaussian. For a given $D$, the conditional distribution $p(L|D,\mathcal{D}_{\mathrm{train}})$ can then be approximated by $\frac{1}{N}\sum_{n}p(L|D,x^{n})$. We conduct experiments on three datasets [52] and evaluate accuracy (ACC) and area under the curve (AUC) on the test subsets. As shown in Table 2, EDG consistently achieves the highest accuracy and AUC.
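For concreteness, the sketch below computes two quantities for this model:

- the log-posterior $\log\pi(x)$, up to an additive constant, which is $-U(x)$ for the sampler;
- the posterior-predictive average $\frac{1}{N}\sum_n p(L=1|D,x^n)$.

It assumes $p(L=1|D,x)$ is the logistic sigmoid of $w^{\top}D+b$, stores $x$ as the list $[w_1,\ldots,w_d,b]$, and drops the Gaussian prior's normalizing constant. The function names are hypothetical.

```python
import math


def sigmoid(a):
    # Numerically stable logistic function
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def log_sigmoid(a):
    # log(sigmoid(a)) without overflow; log(1 - sigmoid(a)) = log_sigmoid(-a)
    return -math.log1p(math.exp(-a)) if a >= 0 else a - math.log1p(math.exp(a))


def log_posterior_unnorm(x, data):
    # log p(x) + sum_{(L, D)} log p(L | D, x) with a standard Gaussian prior (constant dropped)
    w, b = x[:-1], x[-1]
    lp = -0.5 * sum(v * v for v in x)
    for label, feats in data:
        a = sum(wi * fi for wi, fi in zip(w, feats)) + b
        lp += log_sigmoid(a) if label == 1 else log_sigmoid(-a)
    return lp


def predictive(feats, samples):
    # p(L = 1 | D, D_train) ~= (1/N) sum_n p(L = 1 | D, x^n)
    probs = [sigmoid(sum(wi * fi for wi, fi in zip(x[:-1], feats)) + x[-1]) for x in samples]
    return sum(probs) / len(probs)
```

A sampler only needs `log_posterior_unnorm` (as $-U$). The predictive average is then formed from its posterior samples.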
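Table 2: Classification accuracy and AUC on the Bayesian logistic regression tasks. All methods use the same training/test data splits, and the HMC step size is set to 0.01. Mean accuracy and AUC values, together with their standard deviations, are computed over 32 independent runs for all datasets.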
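We extend our analysis to the binary Covertype dataset, which consists of 581,012 data points with 54 features. The posterior over the classifier parameters follows a hierarchical Bayesian model (see Section 5 of [27]). Here $x$ denotes the combination of the classifier parameters and the hyperparameters of the hierarchical Bayesian model. For computational efficiency, BG and EDG use an unbiased approximation of $\log\pi(x)$ during training: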
$$\log\pi(x)\approx\log p(x)+\frac{|\mathcal{D}_{\mathrm{train}}|}{|\mathcal{B}|}\sum_{(L,D)\in\mathcal{B}}\log p(L|D,x),$$
where $\mathcal{B}$ is a random minibatch. For V-HMC and L2HMC, the exact posterior density is computed. The results in Table 3 show that EDG consistently outperforms the other methods.
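The rescaling factor $|\mathcal{D}_{\mathrm{train}}|/|\mathcal{B}|$ is what makes the minibatch estimate unbiased. Averaged over all equally likely minibatches drawn without replacement, it recovers the full-data log-posterior exactly. The sketch below implements the estimator with generic callables and verifies unbiasedness by exhaustive enumeration. The toy prior and Gaussian likelihood used to exercise it are hypothetical, not the hierarchical model of [27].

```python
import itertools
import random


def minibatch_log_posterior(x, batch, n_train, log_prior, log_lik):
    # log p(x) + (|D_train| / |B|) * sum_{item in B} log p(item | x)
    return log_prior(x) + n_train / len(batch) * sum(log_lik(x, item) for item in batch)


def sample_minibatch_log_posterior(x, data, batch_size, log_prior, log_lik, rng=None):
    # Draw a uniform minibatch without replacement and evaluate the estimator
    if rng is None:
        rng = random.Random()
    batch = rng.sample(data, batch_size)
    return minibatch_log_posterior(x, batch, len(data), log_prior, log_lik)


def average_over_all_batches(x, data, batch_size, log_prior, log_lik):
    # Exact expectation of the estimator over all batches of the given size
    vals = [minibatch_log_posterior(x, list(b), len(data), log_prior, log_lik)
            for b in itertools.combinations(data, batch_size)]
    return sum(vals) / len(vals)
```

Unbiasedness holds for $\log\pi$ itself, not for $\pi$. Exponentiating the noisy estimate would introduce bias, which is one reason exact-density methods such as V-HMC and L2HMC use the full data.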
Table 3: Classification accuracy on the Covertype test set. Reported values are the mean accuracy and standard deviation over 32 independent runs.
Table 4: Estimates of $\log Z_{\mathrm{Ising}}$ for the two-dimensional Ising model of dimension 256 ($16\times 16$), obtained with the method described in Section 3.3. We use a batch size of $n=256$ to estimate the mean. By the central limit theorem, the standard deviation of the sample mean is $\mathrm{std}/\sqrt{n}$.
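Ising model. Finally, we validate the performance of EDG on the two-dimensional Ising model [17], a mathematical model of ferromagnetism in statistical mechanics. To make the physical variables continuous, we use the continuous relaxation trick [53]. It converts the discrete variables into continuous auxiliary variables with target distribution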
$$\pi(\boldsymbol{x})=\exp\left(-\frac{1}{2}\boldsymbol{x}^{T}\left(K(T)+\alpha I\right)^{-1}\boldsymbol{x}\right)\times\prod_{i=1}^{N}\cosh\left(x_{i}\right),$$
where $K$ is an $N\times N$ symmetric matrix that depends on the temperature $T$, and $\alpha$ is a constant that guarantees $K+\alpha I$ is positive definite. For the corresponding discrete Ising variables $\boldsymbol{s}\in\{1,-1\}^{N}$, discrete samples can be obtained directly from $\pi(\boldsymbol{s}|\boldsymbol{x})=\prod_{i}\left(1+e^{-2s_{i}x_{i}}\right)^{-1}$.

When there is no external magnetic field and each spin interacts only with its nearest neighbors, $K$ is defined through the coupling $\sum_{\langle ij\rangle}s_{i}s_{j}/T$, where $\langle ij\rangle$ denotes summation over nearest-neighbor pairs. Consequently, the normalizing constant of the continuously relaxed system is given by $\log Z=\log Z_{\text{Ising}}+\frac{1}{2}\ln\det(K+\alpha I)-\frac{N}{2}[\ln(2/\pi)-\alpha]$ [17].

Using the method described in Section 3.3, we report lower-bound estimates of $\log Z_{\text{Ising}}$ at different temperatures for samples generated by NeuralRG, PIS, and EDG. Since these are lower bounds, larger values indicate more accurate estimates. As shown in Table 4, EDG provides the most accurate estimates of $\log Z$ over most of the temperature range. Figure 3 shows the states generated at different temperatures.
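The conversion between $\log Z$ of the relaxed density and $\log Z_{\text{Ising}}$ is a constant shift, so a lower bound on one yields a lower bound on the other. The sketch below implements that shift, the spin-recovery rule $\pi(s_i=+1|\boldsymbol{x})=(1+e^{-2x_i})^{-1}$, and a brute-force $Z_{\text{Ising}}$ for tiny systems. It rests on these assumptions:

- $K_{ij}=1/T$ for nearest neighbors, so that $\tfrac12\boldsymbol{s}^{\top}K\boldsymbol{s}=\sum_{\langle ij\rangle}s_is_j/T$;
- $Z_{\text{Ising}}=\sum_{\boldsymbol{s}}\exp(\tfrac12\boldsymbol{s}^{\top}K\boldsymbol{s})$;
- the lattice has periodic boundaries.

All function names are hypothetical. For two spins, the formula can be checked against direct numerical integration of the relaxed density.

```python
import itertools
import math


def shifted(K, alpha):
    # K + alpha * I
    n = len(K)
    return [[K[i][j] + (alpha if i == j else 0.0) for j in range(n)] for i in range(n)]


def logdet_spd(A):
    # log det A via a pure-Python Cholesky factorization (ValueError if A is not PD)
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(s)
            else:
                L[i][j] = s / L[j][j]
    return 2.0 * sum(math.log(L[i][i]) for i in range(n))


def log_z_relaxed(log_z_ising, K, alpha):
    # log Z = log Z_Ising + 1/2 ln det(K + alpha I) - N/2 [ln(2/pi) - alpha]
    n = len(K)
    return log_z_ising + 0.5 * logdet_spd(shifted(K, alpha)) - 0.5 * n * (math.log(2.0 / math.pi) - alpha)


def log_z_ising_from_relaxed(log_z, K, alpha):
    # Inverse of log_z_relaxed; maps a lower bound on log Z to one on log Z_Ising
    n = len(K)
    return log_z - 0.5 * logdet_spd(shifted(K, alpha)) + 0.5 * n * (math.log(2.0 / math.pi) - alpha)


def log_z_ising_bruteforce(K):
    # log sum_s exp(0.5 s^T K s) by enumeration; feasible only for a handful of spins
    n = len(K)
    terms = [0.5 * sum(s[i] * K[i][j] * s[j] for i in range(n) for j in range(n))
             for s in itertools.product((1, -1), repeat=n)]
    m = max(terms)
    return m + math.log(sum(math.exp(v - m) for v in terms))


def nearest_neighbor_coupling(L, temperature):
    # Hypothetical K for an L x L periodic lattice (L >= 3 avoids double bonds)
    assert L >= 3
    n = L * L
    K = [[0.0] * n for _ in range(n)]
    for r in range(L):
        for c in range(L):
            i = r * L + c
            for j in (r * L + (c + 1) % L, ((r + 1) % L) * L + c):
                K[i][j] = K[j][i] = 1.0 / temperature
    return K


def spin_prob_up(x_i):
    # pi(s_i = +1 | x) = 1 / (1 + exp(-2 x_i)); pi(s_i = -1 | x) = spin_prob_up(-x_i)
    return 1.0 / (1.0 + math.exp(-2.0 * x_i))


def sample_spins(x, rng):
    # Draw s ~ pi(s | x) independently per site; rng is e.g. random.Random(seed)
    return [1 if rng.random() < spin_prob_up(xi) else -1 for xi in x]
```

For the $16\times 16$ lattice, `nearest_neighbor_coupling(16, T)` gives a $256\times 256$ matrix. The pure-Python Cholesky is adequate for a one-off constant, but a linear-algebra library would normally be used.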
Figure 3: States ($16\times 16$) generated by EDG for the 256-dimensional model at temperatures from $T=2.0$ to $T=2.7$, with the latent variable $z_0$ held fixed. As the temperature increases, the states gradually become disordered.
5 Conclusion
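In summary, we introduce EDG, an innovative and effective sampling method that combines principles of VI-based and diffusion-based approaches. Drawing inspiration from VAEs, EDG efficiently generates samples from complex Boltzmann distributions. By exploiting the expressiveness of diffusion models, our method accurately estimates the KL divergence without numerically solving ODEs or SDEs. Empirical experiments confirm the superior sampling performance of EDG.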
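Given its strong generative capability and theoretically unconstrained network design, EDG leaves room for further exploration. Task-specific network architectures could be designed to find the best-performing structure for each problem. Nevertheless, it is a first attempt to design a one-shot generator with the assistance of diffusion models.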
In future work, we will focus on extending EDG to build generative models for large-scale physical and chemical systems such as proteins [54, 55].
Acknowledgements
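The first and third authors were supported by the National Natural Science Foundation of China (Grant No. 12171367). The second author was supported by the National Natural Science Foundation of China (Grant Nos. 92270115, 12071301), the Science and Technology Commission of Shanghai Municipality (Grant No. 20JC1412500), and the Henan Academy of Sciences. The last author was supported by the National Natural Science Foundation of China (Grant No. 12288201), the Strategic Priority Research Program of the Chinese Academy of Sciences (Grant No. XDA25010404), the National Key R&D Program of China (2020YFA0712000), the Youth Innovation Promotion Association of the Chinese Academy of Sciences, and the Henan Academy of Sciences.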
Appendix A Proof of the independence between $z_T$ and $x$
In the decoding process, if $z_T$ is independent of $z_0$, then
$$\begin{aligned}p_{D}(z_{T}|x)&=\int p_{D}(z_{T}|z_{0})\,p_{D}(z_{0}|x)\,\mathrm{d}z_{0}\\&=p_{D}(z_{T})\cdot\int p_{D}(z_{0}|x)\,\mathrm{d}z_{0}\\&=p_{D}(z_{T}).\end{aligned}$$
Appendix B Proof of (5)
For a small lag time $\tau$, the Euler-Maruyama discretization of $z_{[\cdot]}$ in the decoding process gives
$$z_{t}=z_{t-\tau}+f(z_{t-\tau},t-\tau)\,\tau+\sqrt{\tau}\,g(t-\tau)\,u_{t-\tau},\tag{9}$$
$$z_{t-\tau}=z_{t}-\left(f(z_{t},t)-g(t)^{2}\nabla\log p_{D}(z_{t})\right)\tau+\sqrt{\tau}\,g(t)\,\bar{u}_{t},\tag{10}$$
where $u_{t-\tau},\bar{u}_{t}\sim\mathcal{N}(\cdot|0,I)$. Importantly, $u_{t-\tau}$ is independent of $\{z_{t^{\prime}}|t^{\prime}\leq t-\tau\}$ and $x$, while $\bar{u}_{t}$ is independent of $\{z_{t^{\prime}}|t^{\prime}\geq t\}$. Applying the Euler-Maruyama approximation to $p_{E}(z_{[\cdot]}|x)$ gives
$$z_{t-\tau}=z_{t}-\left(f(z_{t},t)-g(t)^{2}s(z_{t},x,t)\right)\tau+\sqrt{\tau}\,g(t)\,v_{t}\tag{11}$$
with $v_{t}\sim\mathcal{N}(\cdot|0,I)$.
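To make the two discretizations concrete, the scalar sketch below implements one forward step of (9) and one reverse-time step of (11). The toy choices are illustrative assumptions:

- $f(z,t)=-\beta z/2$ and $g(t)=\sqrt{\beta}$;
- $s(z,t)=-z$, the exact score of the stationary $\mathcal{N}(0,1)$ law of this forward SDE, with $x$ dropped.

Under these choices, both recursions approximately preserve $z\sim\mathcal{N}(0,1)$. This is a quick sanity check that the reverse drift enters (11) as $-(f-g^2 s)\tau$.

```python
import math
import random

BETA = 1.0  # hypothetical constant noise rate


def drift(z, t):
    # Toy forward drift f(z, t) = -BETA * z / 2
    return -0.5 * BETA * z


def diffusion(t):
    # Toy diffusion coefficient g(t) = sqrt(BETA)
    return math.sqrt(BETA)


def score_toy(z, t):
    # Exact score of N(0, 1), the stationary law of the toy forward SDE (x dropped)
    return -z


def decoder_step(z_prev, t_prev, tau, f, g, u):
    # (9): z_t = z_{t-tau} + f(z_{t-tau}, t-tau) tau + sqrt(tau) g(t-tau) u, t_prev = t - tau
    return z_prev + f(z_prev, t_prev) * tau + math.sqrt(tau) * g(t_prev) * u


def encoder_step(z_t, t, tau, f, g, s, v):
    # (11): z_{t-tau} = z_t - (f(z_t, t) - g(t)^2 s(z_t, t)) tau + sqrt(tau) g(t) v
    return z_t - (f(z_t, t) - g(t) ** 2 * s(z_t, t)) * tau + math.sqrt(tau) * g(t) * v


def second_moment_after(direction, n_paths=5000, steps=50, tau=0.01, seed=0):
    # Start from N(0, 1) and run `steps` forward ("decode") or reverse ("encode") steps
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_paths):
        z = rng.gauss(0.0, 1.0)
        t = 0.0 if direction == "decode" else steps * tau
        for _ in range(steps):
            if direction == "decode":
                z = decoder_step(z, t, tau, drift, diffusion, rng.gauss(0.0, 1.0))
                t += tau
            else:
                z = encoder_step(z, t, tau, drift, diffusion, score_toy, rng.gauss(0.0, 1.0))
                t -= tau
        total += z * z
    return total / n_paths
```

Both `second_moment_after("decode")` and `second_moment_after("encode")` should stay close to 1. Flipping the sign of the $g^2 s$ term in `encoder_step` would make the reverse recursion drift away from $\mathcal{N}(0,1)$.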
From (10) and (11), we obtain
𝔼 p D [ log p D ( z t − τ | z t ) ] subscript 𝔼 subscript 𝑝 𝐷 delimited-[] subscript 𝑝 𝐷 conditional subscript 𝑧 𝑡 𝜏 subscript 𝑧 𝑡 \displaystyle\mathbb{E}_{p_{D}}\left[\log p_{D}(z_{t-\tau}|z_{t})\right] blackboard_E start_POSTSUBSCRIPT italic_p start_POSTSUBSCRIPT italic_D end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ roman_log italic_p start_POSTSUBSCRIPT italic_D end_POSTSUBSCRIPT ( italic_z start_POSTSUBSCRIPT italic_t - italic_τ end_POSTSUBSCRIPT | italic_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) ]
= \displaystyle= =
− d 2 log 2 π − d 2 log τ g ( t ) 2 𝑑 2 2 𝜋 𝑑 2 𝜏 𝑔 superscript 𝑡 2 \displaystyle-\frac{d}{2}\log 2\pi-\frac{d}{2}\log\tau g(t)^{2} - divide start_ARG italic_d end_ARG start_ARG 2 end_ARG roman_log 2 italic_π - divide start_ARG italic_d end_ARG start_ARG 2 end_ARG roman_log italic_τ italic_g ( italic_t ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT
(12)
− 1 2 τ − 1 g ( t ) − 2 𝔼 p D [ ‖ z t − τ − z t + ( f