License: arXiv.org perpetual non-exclusive license
arXiv:2305.03935v4 [cs.LG] 06 Apr 2024

Improved Techniques for Maximum Likelihood Estimation for Diffusion ODEs

Kaiwen Zheng    Cheng Lu    Jianfei Chen    Jun Zhu
Abstract

Diffusion models have exhibited excellent performance in various domains. The probability flow ordinary differential equation (ODE) of diffusion models (i.e., diffusion ODEs) is a particular case of continuous normalizing flows (CNFs), which enables deterministic inference and exact likelihood evaluation. However, the likelihood estimates of diffusion ODEs still lag far behind those of state-of-the-art likelihood-based generative models. In this work, we propose several improved techniques for maximum likelihood estimation for diffusion ODEs, from both the training and the evaluation perspectives. For training, we propose velocity parameterization and explore variance reduction techniques for faster convergence. We also derive an error-bounded high-order flow matching objective for finetuning, which improves the ODE likelihood and smooths its trajectory. For evaluation, we propose a novel training-free truncated-normal dequantization to close the training-evaluation gap common in diffusion ODEs. Building upon these techniques, we achieve state-of-the-art likelihood estimation results on image datasets (2.56 on CIFAR-10, 3.43/3.69 on ImageNet-32) without variational dequantization or data augmentation, and 2.42 on CIFAR-10 with data augmentation. Code is available at https://github.com/thu-ml/i-DODE.

Machine Learning, ICML

1 Introduction

Likelihood is an important metric for evaluating density estimation models, and accurate likelihood estimation is key to many applications such as data compression (Ho et al., 2021; Helminger et al., 2020; Kingma et al., 2021; Yang & Mandt, 2022), anomaly detection (Chen et al., 2018c; Dias et al., 2020) and out-of-distribution detection (Serrà et al., 2020; Xiao et al., 2020). Many deep generative models can compute tractable likelihoods, including autoregressive models (Oord et al., 2016; Salimans et al., 2017; Chen et al., 2018b), variational auto-encoders (VAE) (Kingma & Welling, 2014; Vahdat & Kautz, 2020), normalizing flows (Dinh et al., 2017; Kingma & Dhariwal, 2018; Ho et al., 2019) and diffusion models (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020; Song et al., 2021c, a; Karras et al., 2022). Among these, variational diffusion models (VDM) (Kingma et al., 2021), a variant of diffusion models, achieve state-of-the-art likelihood estimation performance on standard image density estimation benchmarks.

There are two types of diffusion models: one is based on the reverse stochastic differential equation (SDE) (Song et al., 2021c), named diffusion SDE; the other is based on the probability flow ordinary differential equation (ODE) (Song et al., 2021c), named diffusion ODE. The two types define and evaluate the likelihood in different ways: a diffusion SDE can be understood as an infinitely deep VAE (Huang et al., 2021) and can only compute a variational lower bound of the likelihood (Song et al., 2021c; Kingma et al., 2021), whereas a diffusion ODE is a variant of continuous normalizing flows (Chen et al., 2018a) and can compute the exact likelihood with ODE solvers. It is therefore natural to hypothesize that diffusion ODEs may achieve better likelihood than diffusion SDEs. However, none of the existing methods for training diffusion ODEs (Song et al., 2021b; Lu et al., 2022a; Lipman et al., 2022; Albergo & Vanden-Eijnden, 2022; Liu et al., 2022b) even matches the likelihood of VDM, which is a diffusion SDE. Whether diffusion ODEs can also be strong likelihood estimators remains largely open.

Real-world data is usually discrete, so evaluating the likelihood of discrete data with diffusion ODEs first requires a dequantization process (Dinh et al., 2017; Salimans et al., 2017) that makes the input to the diffusion ODE continuous. In this work, we observe that previous likelihood evaluations of diffusion ODEs have flaws in the dequantization process: uniform dequantization (Song et al., 2021b) causes a large training-evaluation gap, while variational dequantization (Ho et al., 2019; Song et al., 2021b) requires additional training overhead and is hard to train to optimality.

In this work, we propose several improved techniques, from both the evaluation and the training perspectives, that allow diffusion ODEs to outperform existing state-of-the-art likelihood estimators. For evaluation, we propose a training-free dequantization method dedicated to diffusion models, based on a carefully designed truncated-normal distribution, which fits diffusion ODEs well and improves the likelihood evaluation by a large margin compared to uniform dequantization. We also introduce an importance-weighted likelihood estimator to obtain a tighter bound. For training, we split our procedure into pretraining and finetuning phases. For pretraining, we propose a new model parameterization, including velocity parameterization (an extended version of flow matching (Lipman et al., 2022) with practical modifications) and log-signal-to-noise-ratio timed parameterization. In addition, we find a simple yet effective importance sampling strategy for variance reduction. Together, these make our pretraining converge faster than previous work. For finetuning, we propose an error-bounded high-order flow matching objective, which not only improves the ODE likelihood but also results in smoother trajectories. We name the resulting framework Improved Diffusion ODE (i-DODE).

We conduct ablation studies to demonstrate the effectiveness of each component. Empirically, we achieve state-of-the-art likelihood on image datasets (2.56 on CIFAR-10, 3.43/3.69 on ImageNet-32), surpassing the previous best diffusion ODE results of 2.90 and 3.48/3.82, while using no data augmentation and removing the need to train variational dequantization models.

2 Diffusion Models

2.1 Diffusion ODEs and Maximum Likelihood Training

Suppose we have a $d$-dimensional data distribution $q_0(\bm{x}_0)$. Diffusion models (Ho et al., 2020; Song et al., 2021c) gradually diffuse the data by a forward stochastic differential equation (SDE) starting from $\bm{x}_0\sim q_0(\bm{x}_0)$:

$$\mathrm{d}\bm{x}_t=f(t)\bm{x}_t\,\mathrm{d}t+g(t)\,\mathrm{d}\bm{w}_t,\quad \bm{x}_0\sim q_0(\bm{x}_0), \tag{1}$$

where $f(t),g(t)\in\mathbb{R}$ are manually designed noise schedules and $\bm{w}_t\in\mathbb{R}^d$ is a standard Wiener process. The forward process $\{\bm{x}_t\}_{t\in[0,T]}$ is accompanied by a series of marginal distributions $\{q_t\}_{t\in[0,T]}$ such that $q_T(\bm{x}_T)\approx\mathcal{N}(\bm{x}_T|\bm{0},\sigma_T^2\bm{I})$ for some constant $\sigma_T>0$.
Since this is a simple linear SDE, the transition kernel is an analytical Gaussian (Song et al., 2021c): $q_{0t}(\bm{x}_t|\bm{x}_0)=\mathcal{N}(\alpha_t\bm{x}_0,\sigma_t^2\bm{I})$, where the coefficients satisfy $f(t)=\frac{\mathrm{d}\log\alpha_t}{\mathrm{d}t}$ and $g^2(t)=\frac{\mathrm{d}\sigma_t^2}{\mathrm{d}t}-2\frac{\mathrm{d}\log\alpha_t}{\mathrm{d}t}\sigma_t^2$ (Kingma et al., 2021). Under some regularity conditions (Anderson, 1982), the forward process has an equivalent probability flow ODE (Song et al., 2021c):

$$\frac{\mathrm{d}\bm{x}_t}{\mathrm{d}t}=f(t)\bm{x}_t-\frac{1}{2}g^2(t)\nabla_{\bm{x}}\log q_t(\bm{x}_t), \tag{2}$$

which produces the same marginal distribution $q_t$ at each time $t$ as Eqn. (1). The only unknown term, $\nabla_{\bm{x}}\log q_t(\bm{x}_t)$, is the score function of $q_t$. By parameterizing a score network $\bm{s}_\theta(\bm{x}_t,t)$ to predict the time-dependent $\nabla_{\bm{x}}\log q_t(\bm{x}_t)$, we can replace the true score function, resulting in the diffusion ODE (Song et al., 2021c):

$$\frac{\mathrm{d}\bm{x}_t}{\mathrm{d}t}=f(t)\bm{x}_t-\frac{1}{2}g^2(t)\bm{s}_\theta(\bm{x}_t,t), \tag{3}$$

with the associated marginal distributions $\{p_t\}_{t\in[0,T]}$. Diffusion ODEs are special cases of continuous normalizing flows (CNFs) (Chen et al., 2018a), and can thus perform exact inference of the latents and exact likelihood evaluation.
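The probability flow ODE of Eqn. (2) can be checked numerically on a toy problem where the score is known in closed form. The following is a minimal sketch in plain Python, assuming (for illustration only) one-dimensional Gaussian data $q_0=\mathcal{N}(\mu,s_0^2)$ and a hypothetical VP schedule whose log-SNR is linear in $t$. Then $q_t=\mathcal{N}(\alpha_t\mu,\alpha_t^2s_0^2+\sigma_t^2)$ has an analytic score, and the ODE should carry a point at standardized position $z$ under $q_T$ to the same standardized position under $q_0$. The constants `MU`, `S0`, `G_MIN`, `G_MAX` and all function names are our own choices, not taken from the paper's code.

```python
import math

# Hypothetical toy setup (not from the paper): 1-D Gaussian data q_0 = N(MU, S0^2)
# and a VP schedule with gamma_t = -log(alpha_t^2 / sigma_t^2) linear in t on [0, 1].
MU, S0 = 0.3, 0.5
G_MIN, G_MAX = -10.0, 10.0


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def alpha_sigma(t):
    # alpha_t^2 + sigma_t^2 = 1 together with gamma_t = log(sigma_t^2 / alpha_t^2)
    # gives sigma_t^2 = sigmoid(gamma_t).
    s2 = sigmoid(G_MIN + (G_MAX - G_MIN) * t)
    return math.sqrt(1.0 - s2), math.sqrt(s2)


def f_g2(t):
    # f(t) = d log(alpha_t)/dt and g^2(t) = d sigma_t^2/dt - 2 f(t) sigma_t^2,
    # evaluated in closed form for this particular schedule.
    dgamma_dt = G_MAX - G_MIN
    s2 = sigmoid(G_MIN + dgamma_dt * t)
    return -0.5 * s2 * dgamma_dt, s2 * dgamma_dt


def true_score(x, t):
    # grad_x log q_t(x) for q_t = N(alpha_t * MU, alpha_t^2 * S0^2 + sigma_t^2)
    a, s = alpha_sigma(t)
    return -(x - a * MU) / (a * a * S0 * S0 + s * s)


def drift(x, t):
    # Right-hand side of the probability flow ODE, Eqn. (2)
    f, g2 = f_g2(t)
    return f * x - 0.5 * g2 * true_score(x, t)


def solve_ode(x, t_start, t_end, n_steps=1000):
    # Classical fourth-order Runge-Kutta with a fixed step (negative when going back in time)
    h = (t_end - t_start) / n_steps
    for i in range(n_steps):
        t = t_start + i * h
        k1 = drift(x, t)
        k2 = drift(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = drift(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = drift(x + h * k3, t + h)
        x += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
    return x


def gaussian_flow(z, t):
    # The point at standardized position z under q_t
    a, s = alpha_sigma(t)
    return a * MU + math.sqrt(a * a * S0 * S0 + s * s) * z


x_T = gaussian_flow(1.5, 1.0)
x_0 = solve_ode(x_T, 1.0, 0.0)
print(x_0, gaussian_flow(1.5, 0.0))  # the two values should agree closely
```

Replacing `true_score` inside `drift` with a trained network $\bm{s}_\theta$ would turn this into the diffusion ODE of Eqn. (3); the fixed-step solver is only for illustration, and practical likelihood evaluation typically uses adaptive solvers.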

Although traditional maximum likelihood training methods for CNFs (Grathwohl et al., 2019) are feasible for diffusion ODEs, they are expensive and hard to scale up because they require solving ODEs at every training iteration. A more practical alternative is to match the generative probability flow $\{p_t\}_{t\in[0,T]}$ with $\{q_t\}_{t\in[0,T]}$ in a simulation-free manner. Specifically, Lu et al. (2022a) prove that $D_{\mathrm{KL}}(q_0\,\|\,p_0^{\text{ODE}})$ can be decomposed as $D_{\mathrm{KL}}(q_0\,\|\,p_0^{\text{ODE}})=D_{\mathrm{KL}}(q_T\,\|\,p_T^{\text{ODE}})+\mathcal{J}_{\text{ODE}}(\theta)$, where

$$\mathcal{J}_{\text{ODE}}(\theta)\coloneqq\frac{1}{2}\int_0^T g(t)^2\,\mathbb{E}_{q_t(\bm{x}_t)}\Big[\big(\bm{s}_\theta(\bm{x}_t,t)-\nabla_{\bm{x}}\log q_t(\bm{x}_t)\big)^\top\big(\nabla_{\bm{x}}\log p_t(\bm{x}_t)-\nabla_{\bm{x}}\log q_t(\bm{x}_t)\big)\Big]\mathrm{d}t \tag{4}$$

However, computing $\nabla_{\bm{x}}\log p_t(\bm{x}_t)$ requires solving another ODE and is also expensive (Lu et al., 2022a). To minimize $\mathcal{J}_{\text{ODE}}(\theta)$ in a simulation-free manner, Lu et al. (2022a) also propose a combination of $g^2(t)$-weighted first-order and high-order score matching objectives. In particular, the first-order score matching objective is

$$\mathcal{J}_{\text{SM}}(\theta)\coloneqq\int_0^T\frac{g^2(t)}{2\sigma_t^2}\,\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\|\sigma_t\bm{s}_\theta(\bm{x}_t,t)+\bm{\epsilon}\|_2^2\right]\mathrm{d}t, \tag{5}$$

where $\bm{x}_t=\alpha_t\bm{x}_0+\sigma_t\bm{\epsilon}$, $\bm{x}_0\sim q_0(\bm{x}_0)$ and $\bm{\epsilon}\sim\mathcal{N}(\bm{\epsilon}|\bm{0},\bm{I})$.
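Eqn. (5) can be estimated by plain Monte Carlo over $t$, $\bm{x}_0$ and $\bm{\epsilon}$, with no ODE solve. The sketch below assumes the same hypothetical toy setup as a sanity check (1-D Gaussian data, VP schedule with linear log-SNR, $T=1$), samples $t$ uniformly, and uses `score_fn` as a stand-in for $\bm{s}_\theta$. With the exact marginal score the estimate is clearly smaller than with a zero score, but it is not zero: Eqn. (5) is the denoising form of score matching, which differs from the explicit score-matching error by a $\theta$-independent constant.

```python
import math
import random

# Hypothetical toy setup: q_0 = N(MU, S0^2) and a VP schedule with linear log-SNR on [0, T].
MU, S0 = 0.3, 0.5
G_MIN, G_MAX = -10.0, 10.0
T = 1.0


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def alpha_sigma(t):
    s2 = sigmoid(G_MIN + (G_MAX - G_MIN) * t)
    return math.sqrt(1.0 - s2), math.sqrt(s2)


def g2(t):
    # g^2(t) = sigma_t^2 * d gamma/dt for this schedule
    return sigmoid(G_MIN + (G_MAX - G_MIN) * t) * (G_MAX - G_MIN)


def j_sm_estimate(score_fn, n=20000, seed=0):
    """Unbiased Monte Carlo estimate of J_SM in Eqn. (5) with t ~ U[0, T]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        t = rng.uniform(0.0, T)
        x0 = rng.gauss(MU, S0)          # x_0 ~ q_0
        eps = rng.gauss(0.0, 1.0)       # eps ~ N(0, 1)
        a, s = alpha_sigma(t)
        xt = a * x0 + s * eps           # x_t = alpha_t x_0 + sigma_t eps
        weight = g2(t) / (2.0 * s * s)  # likelihood weighting g^2(t) / (2 sigma_t^2)
        total += weight * (s * score_fn(xt, t) + eps) ** 2
    return T * total / n                # multiply by T, the inverse sampling density of t


def true_score(x, t):
    a, s = alpha_sigma(t)
    return -(x - a * MU) / (a * a * S0 * S0 + s * s)


print(j_sm_estimate(true_score), j_sm_estimate(lambda x, t: 0.0))
```

Reusing the same seed for both calls gives common random numbers, so the comparison between the two score functions is not dominated by sampling noise.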

2.2 Log-SNR Timed Diffusion Models

Diffusion models use a manually designed noise schedule $\alpha_t,\sigma_t$, which has many degrees of freedom and affects performance. Even within a restricted design space such as Variance Preserving (VP) (Song et al., 2021c), which constrains the noise schedule by $\alpha_t^2+\sigma_t^2=1$, there are still various choices for how fast $\alpha_t,\sigma_t$ change with time $t$. To decouple the specific schedule form, variational diffusion models (VDM) (Kingma et al., 2021) use the negative log-signal-to-noise ratio (log-SNR) as the time variable, which greatly simplifies both the noise schedules and the training objectives. Specifically, denote $\gamma_t=-\text{log-SNR}(t)=-\log\frac{\alpha_t^2}{\sigma_t^2}$; the change-of-variable relation from $\gamma$ to $t$ is

$$\frac{\mathrm{d}\gamma}{\mathrm{d}t}=\frac{g^2(t)}{\sigma_t^2}, \tag{6}$$

and by replacing the time subscript with $\gamma$, we get the simplified score matching objective with likelihood weighting:

$$\mathcal{J}_{\text{SM}}(\theta)=\frac{1}{2}\int_{\gamma_0}^{\gamma_T}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\|\sigma_t\bm{s}_\theta(\bm{x}_\gamma,\gamma)+\bm{\epsilon}\|_2^2\right]\mathrm{d}\gamma \tag{7}$$

This result is in accordance with the continuous diffusion loss in Kingma et al. (2021).
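Eqn. (6) and the reweighting behind Eqn. (7) are easy to sanity-check numerically. The sketch below assumes the linear-$\beta$ VP schedule of Song et al. (2021c) with $\beta_0=0.1$, $\beta_1=20$ (our choice for illustration), obtains $f(t)$ and $g^2(t)$ from $\alpha_t,\sigma_t$ by central finite differences, and checks that $g^2(t)/\sigma_t^2$ matches $\mathrm{d}\gamma/\mathrm{d}t$ and that the total likelihood weight $\int_{t_0}^{T} \frac{g^2(t)}{2\sigma_t^2}\,\mathrm{d}t$ equals $(\gamma_T-\gamma_{t_0})/2$, i.e., the weighting becomes uniform in $\gamma$.

```python
import math

# Assumed schedule for illustration: linear-beta VP SDE, beta(t) = BETA0 + (BETA1 - BETA0) t.
BETA0, BETA1 = 0.1, 20.0


def log_alpha(t):
    return -0.25 * (BETA1 - BETA0) * t * t - 0.5 * BETA0 * t


def sigma2(t):
    # VP: sigma_t^2 = 1 - alpha_t^2; expm1 keeps precision when t is small
    return -math.expm1(2.0 * log_alpha(t))


def gamma(t):
    # gamma_t = -log(alpha_t^2 / sigma_t^2)
    return math.log(sigma2(t)) - 2.0 * log_alpha(t)


def deriv(fn, t, h=1e-6):
    # Central finite difference
    return (fn(t + h) - fn(t - h)) / (2.0 * h)


def g2(t):
    # g^2(t) = d sigma_t^2/dt - 2 f(t) sigma_t^2, with f(t) = d log(alpha_t)/dt
    return deriv(sigma2, t) - 2.0 * deriv(log_alpha, t) * sigma2(t)


def total_weight(t0, t1, n=20000):
    # Midpoint rule for int_{t0}^{t1} g^2(t) / (2 sigma_t^2) dt
    h = (t1 - t0) / n
    mids = (t0 + (i + 0.5) * h for i in range(n))
    return h * sum(g2(t) / (2.0 * sigma2(t)) for t in mids)


for t in (0.01, 0.5, 0.99):
    print(t, g2(t) / sigma2(t), deriv(gamma, t))  # Eqn. (6): the last two columns agree
print(total_weight(0.01, 1.0), 0.5 * (gamma(1.0) - gamma(0.01)))
```

The lower limit $t_0=0.01$ is arbitrary; it only needs $\sigma_{t_0}>0$ so that $\gamma_{t_0}$ is finite.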

2.3 Dequantization for Density Estimation

Many real-world datasets, such as images or texts, contain discrete data. In such cases, fitting a continuous density model to the discrete data points leads to degenerate results (Uria et al., 2013) and cannot provide meaningful density estimates. A common solution is dequantization (Dinh et al., 2017; Salimans et al., 2017; Ho et al., 2019). Specifically, suppose $\bm{x}_0$ is 8-bit discrete data scaled to $[-1,1]$. Dequantization methods assume that we have trained a continuous model distribution $p_{\text{model}}$ for $\bm{x}_0$, and define the discrete model distribution by

$$P_{\text{model}}(\bm{x}_0)\coloneqq\int_{[-\frac{1}{256},\frac{1}{256})^d}p_{\text{model}}(\bm{x}_0+\bm{u})\,\mathrm{d}\bm{u}.$$

To train $P_{\text{model}}(\bm{x}_0)$ by maximum likelihood estimation, variational dequantization (Ho et al., 2019) introduces a dequantization distribution $q(\bm{u}|\bm{x}_0)$ and jointly trains $p_{\text{model}}$ and $q(\bm{u}|\bm{x}_0)$ with a variational lower bound:

$$\log P_{\text{model}}(\bm{x}_0)\geq\mathbb{E}_{q(\bm{u}|\bm{x}_0)}\left[\log p_{\text{model}}(\bm{x}_0+\bm{u})-\log q(\bm{u}|\bm{x}_0)\right].$$

A simple choice for $q(\bm{u}|\bm{x}_0)$ is uniform dequantization, where we set $q(\bm{u}|\bm{x}_0)=\mathcal{U}(-\frac{1}{256},\frac{1}{256})$.
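For a continuous model whose bin probabilities are available in closed form, the uniform-dequantization bound can be compared against the exact $\log P_{\text{model}}(\bm{x}_0)$. The sketch below is illustrative only: it assumes a factorized Gaussian $p_{\text{model}}=\mathcal{N}(\mu,s^2)$ per dimension as a stand-in for a trained model, and assumes pixel value $k\in\{0,\dots,255\}$ is mapped to $(2k+1)/256-1$, so that the bins $[x_0-\frac{1}{256},x_0+\frac{1}{256})$ tile $[-1,1)$. The helper names are hypothetical.

```python
import math
import random

HALF_BIN = 1.0 / 256.0  # u is supported on [-1/256, 1/256)^d


def pixel_to_x0(k):
    # Assumed scaling: 8-bit value k -> bin centre (2k + 1)/256 - 1 in [-1, 1)
    return (2 * k + 1) / 256.0 - 1.0


def log_normal_pdf(x, mu, s):
    return -0.5 * math.log(2.0 * math.pi * s * s) - 0.5 * ((x - mu) / s) ** 2


def normal_cdf(x, mu, s):
    return 0.5 * (1.0 + math.erf((x - mu) / (s * math.sqrt(2.0))))


def exact_log_prob(x0, mu, s):
    # log P_model(x0): log of the model mass on each bin, summed over independent dimensions
    return sum(
        math.log(normal_cdf(xi + HALF_BIN, mu, s) - normal_cdf(xi - HALF_BIN, mu, s))
        for xi in x0
    )


def uniform_dequant_bound(x0, mu, s, n=4000, seed=0):
    # E_q[log p_model(x0 + u)] - E_q[log q(u | x0)] with q = U[-1/256, 1/256)^d
    rng = random.Random(seed)
    entropy = len(x0) * math.log(2.0 * HALF_BIN)  # -E_q[log q] = d * log(2/256)
    total = 0.0
    for _ in range(n):
        total += sum(log_normal_pdf(xi + rng.uniform(-HALF_BIN, HALF_BIN), mu, s) for xi in x0)
    return total / n + entropy


x0 = [pixel_to_x0(k) for k in (64, 128, 166, 230)]
print(uniform_dequant_bound(x0, 0.0, 0.5), exact_log_prob(x0, 0.0, 0.5))
```

Because this Gaussian stand-in is nearly constant within each bin, the bound is tight here; a model with sharper structure inside the bins would show a larger gap.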

3 Diffusion ODEs with Truncated-Normal Dequantization

In this section, we discuss the challenges of training diffusion ODEs with dequantization and propose a training-free dequantization method for diffusion ODEs.

3.1 Challenges for Diffusion ODEs with Dequantization

We first discuss the challenges for diffusion ODEs with dequantization in this section.

Truncation introduces an additional gap.

Theoretically, we want to train diffusion ODEs by minimizing $D_{\mathrm{KL}}(q_0\,\|\,p_0)$ and use $p_0(\bm{x}_0)$ as the continuous model distribution. However, since $\sigma_0=0$, we have $\gamma_0=-\infty$. As a result, previous work (Song et al., 2021c; Kim et al., 2022) has shown that there are numerical issues near $t=0$ for both training and sampling, so we cannot directly compute the model distribution $p_0$ at time $0$. In practice, a common solution is to choose a small starting time $\epsilon>0$ to improve numerical stability. The training objective then becomes minimizing $D_{\mathrm{KL}}(q_\epsilon\,\|\,p_\epsilon)$, which is equivalent to

$$\max_\theta\ \mathbb{E}_{q_0(\bm{x}_0)q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)}\left[\log p_\epsilon(\bm{x}_\epsilon)\right], \tag{8}$$

and $\mathbb{E}_{q_0(\bm{x}_0)}\log p_\epsilon(\bm{x}_0)$ is directly used to evaluate the data likelihood. However, as $p_\epsilon\neq p_0$, such a method will introduce an additional gap due to the mismatch between training ($\mathbb{E}_{q_\epsilon(\bm{x}_\epsilon)}[\log p_\epsilon(\bm{x}_\epsilon)]$) and testing ($\mathbb{E}_{q_0(\bm{x}_0)}[\log p_\epsilon(\bm{x}_0)]$), which may degrade the likelihood evaluation performance.
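The two expectations above differ even for a perfect model. As a toy assumption, take 1-D Gaussian data $q_0=\mathcal{N}(0,s_0^2)$ with a VP schedule ($\alpha_\epsilon=\sqrt{1-\sigma_\epsilon^2}$) and suppose $p_\epsilon=q_\epsilon$ exactly; both quantities are then Gaussian cross-entropies with closed forms, so the minimal sketch below computes them directly. The function names are hypothetical.

```python
import math


def expected_log_normal(m1, v1, m2, v2):
    # E_{x ~ N(m1, v1)}[log N(x | m2, v2)] in closed form
    return -0.5 * math.log(2.0 * math.pi * v2) - (v1 + (m1 - m2) ** 2) / (2.0 * v2)


def train_and_test(sigma_eps, mu=0.0, s0=0.5):
    alpha_eps = math.sqrt(1.0 - sigma_eps ** 2)  # VP: alpha^2 + sigma^2 = 1
    # Perfect model (toy assumption): p_eps = q_eps = N(alpha_eps mu, alpha_eps^2 s0^2 + sigma_eps^2)
    m_eps = alpha_eps * mu
    v_eps = alpha_eps ** 2 * s0 ** 2 + sigma_eps ** 2
    train = expected_log_normal(m_eps, v_eps, m_eps, v_eps)  # E_{q_eps}[log p_eps(x_eps)]
    test = expected_log_normal(mu, s0 ** 2, m_eps, v_eps)    # E_{q_0}[log p_eps(x_0)]
    return train, test


for sigma_eps in (0.1, 0.01):
    train, test = train_and_test(sigma_eps)
    print(sigma_eps, train, test, test - train)
```

For this continuous toy data the difference shrinks as $\sigma_\epsilon\to0$; the example only illustrates that the quantity optimized during training is not the one reported at test time.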

Uniform dequantization causes a train-test mismatch.

After choosing $\epsilon$, the continuous model distribution is defined by $p_{\text{model}}(\bm{x})\coloneqq p_{\epsilon}(\bm{x})$. Let $q(\bm{u}|\bm{x}_{0})$ be a dequantization distribution supported on $\bm{u}\in[-\frac{1}{256},\frac{1}{256})^{d}$. The variational lower bound for the discrete model density $P_{0}(\bm{x}_{0})$ is:

$$\begin{aligned}\mathbb{E}_{q_{0}(\bm{x}_{0})}[\log P_{0}(\bm{x}_{0})]&\geq\mathbb{E}_{q_{0}(\bm{x}_{0})q(\bm{u}|\bm{x}_{0})}\left[\log p_{\epsilon}(\bm{x}_{0}+\bm{u})\right]\\&\quad-\mathbb{E}_{q_{0}(\bm{x}_{0})q(\bm{u}|\bm{x}_{0})}\left[\log q(\bm{u}|\bm{x}_{0})\right].\end{aligned}$$

One widely used choice for $q(\bm{u}|\bm{x}_{0})$ is the uniform distribution (uniform dequantization). However, this leads to a training-evaluation gap: for training, we fit $p_{\epsilon}$ to the distribution $q_{\epsilon}(\bm{x}_{\epsilon})$, which is a Gaussian distribution near each discrete data point $\bm{x}_{0}$ because $\bm{x}_{\epsilon}=\alpha_{\epsilon}\bm{x}_{0}+\sigma_{\epsilon}\bm{\epsilon}$ for $\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})$; while for evaluation, we test $p_{\epsilon}$ on uniformly dequantized data $\bm{x}_{0}+\bm{u}$. This gap also degrades the likelihood evaluation performance and has not been well studied.
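A minimal Monte Carlo sketch of the right-hand side of the bound above under uniform dequantization is given below. It assumes data scaled to $[-1,1]$, so each bin has half-width $\frac{1}{256}$; then $q(\bm{u}|\bm{x}_0)$ has constant density $128^d$ on its support and the entropy term is exactly $-d\log 128$. The callable `log_p` is a hypothetical stand-in for $\log p_\epsilon$; in a real pipeline it would come from solving the diffusion ODE.

```python
import math
import random


def uniform_dequant_bound(x0, log_p, num_samples=16, rng=None):
    """Monte Carlo estimate of E_q[log p(x0 + u)] - E_q[log q(u | x0)]
    with u ~ Uniform[-1/256, 1/256)^d.

    `log_p` is a hypothetical callable returning log p_eps at a point.
    """
    rng = rng or random.Random(0)
    d = len(x0)
    half_width = 1.0 / 256.0
    # Each dimension is uniform on an interval of length 2/256 = 1/128,
    # so log q(u | x0) = d * log(128) everywhere on the support.
    log_q = d * math.log(128.0)
    total = 0.0
    for _ in range(num_samples):
        u = [rng.uniform(-half_width, half_width) for _ in range(d)]
        total += log_p([xi + ui for xi, ui in zip(x0, u)])
    return total / num_samples - log_q
```

Because the entropy term is a constant, the tightness of this bound depends entirely on how well $p_\epsilon$, which was fit to Gaussian-perturbed inputs $\bm{x}_\epsilon$, scores uniformly perturbed inputs $\bm{x}_0+\bm{u}$.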

Alternatively, one can train a variational dequantization model $q_{\phi}(\bm{u}|\bm{x}_{0})$ (Ho et al., 2019; Song et al., 2021b), but this incurs additional cost and is hard to train (Kim et al., 2022).

3.2 Training-Free Dequantization by Truncated Normal

In this section, we show that there exists a training-free dequantization distribution that fits diffusion ODEs well.

As discussed in Sec. 3.1, the gap between training and testing of diffusion ODEs is due to the difference between the training input $\bm{x}_{\epsilon}=\alpha_{\epsilon}\bm{x}_{0}+\sigma_{\epsilon}\bm{\epsilon}$ (where $\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})$) and the testing input $\bm{x}_{0}+\bm{u}$. To fill this gap, we can choose a dequantization distribution $q(\bm{u}|\bm{x}_{0})$ that satisfies

$$\bm{x}_{0}+\bm{u}\approx\alpha_{\epsilon}\bm{x}_{0}+\sigma_{\epsilon}\bm{\epsilon},\quad\bm{u}\in\left[-\frac{1}{256},\frac{1}{256}\right)^{d}. \tag{9}$$

For small enough $\epsilon$, we have $\alpha_{\epsilon}\approx 1$, and Eqn. (9) becomes $\bm{u}\approx\frac{\sigma_{\epsilon}}{\alpha_{\epsilon}}\bm{\epsilon}$. We also need to ensure that the support of $q(\bm{u}|\bm{x}_{0})$ is $[-\frac{1}{256},\frac{1}{256})^{d}$, i.e., that the random variable $\frac{\sigma_{\epsilon}}{\alpha_{\epsilon}}\bm{\epsilon}$ lies approximately within $[-\frac{1}{256},\frac{1}{256})^{d}$. To this end, we choose the dequantization distribution to be a truncated normal distribution:

$$q(\bm{u}|\bm{x}_{0})=\mathcal{TN}\left(\mathbf{0},\frac{\sigma_{\epsilon}^{2}}{\alpha_{\epsilon}^{2}}\bm{I},-\frac{1}{256},\frac{1}{256}\right) \tag{10}$$

where $\mathcal{TN}(\bm{x}|\bm{\mu},\sigma^{2}\bm{I},a,b)$ is the truncated normal distribution obtained by restricting $\mathcal{N}(\bm{\mu},\sigma^{2}\bm{I})$ to the bounds $[a,b]$ in each dimension. Moreover, this truncated-normal dequantization provides a guideline for choosing the start time $\epsilon$: to avoid a large deviation caused by truncating at $\frac{1}{256}$, we need $\frac{\alpha_{\epsilon}}{\sigma_{\epsilon}}\bm{u}\approx\bm{\epsilon}$ in most cases. We leverage the 3-$\sigma$ principle for the standard normal distribution and let $\epsilon$ satisfy $\frac{\alpha_{\epsilon}}{\sigma_{\epsilon}}\bm{u}\in[-3,3]^{d}$.
Since $\bm{u}\in[-\frac{1}{256},\frac{1}{256})^{d}$, the critical start time $\epsilon$ satisfies $\frac{\alpha_{\epsilon}}{256\sigma_{\epsilon}}=3$, i.e., $\frac{\alpha_{\epsilon}}{\sigma_{\epsilon}}=768$, so the negative log-SNR is $\gamma_{\epsilon}=-\log\frac{\alpha_{\epsilon}^{2}}{\sigma_{\epsilon}^{2}}=-2\log 768\approx-13.3$. Surprisingly, this choice of $\gamma_{\epsilon}$ is exactly the same as the $\gamma_{\text{min}}$ in Kingma et al. (2021), which is instead obtained by training. This dequantization distribution approximately satisfies the conditions in Eqn. (9), and we validate in Sec. 6 that it provides a tighter variational bound with no additional training cost. We summarize likelihood evaluation under this dequantization distribution in the following theorem.
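The arithmetic behind $\gamma_\epsilon\approx-13.3$ and the sampling step of Eqn. (10) can be sketched as follows. This is an illustrative sketch, not the released implementation: `vp_alpha_sigma` assumes a variance-preserving schedule ($\alpha_\epsilon^2+\sigma_\epsilon^2=1$) with $\gamma=\log(\sigma^2/\alpha^2)$, and truncated-normal samples are drawn by inverse-CDF sampling with the standard library's `statistics.NormalDist`.

```python
import math
import random
from statistics import NormalDist

HALF_BIN = 1.0 / 256.0  # half-width of one bin for 8-bit data scaled to [-1, 1]


def critical_gamma(num_sigmas=3.0):
    """gamma_eps = -log(alpha^2 / sigma^2) at which HALF_BIN equals num_sigmas
    standard deviations of (sigma_eps / alpha_eps) * eps."""
    alpha_over_sigma = num_sigmas / HALF_BIN  # 3 * 256 = 768
    return -2.0 * math.log(alpha_over_sigma)


def vp_alpha_sigma(gamma):
    """Assumes a variance-preserving schedule with gamma = log(sigma^2 / alpha^2)."""
    alpha2 = 1.0 / (1.0 + math.exp(gamma))
    sigma2 = 1.0 / (1.0 + math.exp(-gamma))
    return math.sqrt(alpha2), math.sqrt(sigma2)


def sample_tn_dequant(d, alpha_eps, sigma_eps, rng=None):
    """Draw u ~ TN(0, (sigma/alpha)^2 I, -1/256, 1/256) as in Eqn. (10),
    by inverse-CDF sampling restricted to the retained probability mass."""
    rng = rng or random.Random(0)
    std = NormalDist()
    scale = sigma_eps / alpha_eps          # std of the normal before truncation
    tau = HALF_BIN / scale                 # truncation bound in standard units
    lo, hi = std.cdf(-tau), std.cdf(tau)   # assumes tau is moderate so lo > 0
    return [scale * std.inv_cdf(rng.uniform(lo, hi)) for _ in range(d)]


gamma_eps = critical_gamma()               # about -13.29
alpha_eps, sigma_eps = vp_alpha_sigma(gamma_eps)
u = sample_tn_dequant(4, alpha_eps, sigma_eps)
```

At the critical $\gamma_\epsilon$ the standardized truncation bound is $\tau=3$, so only about 0.27% of the untruncated Gaussian mass is cut off in each dimension.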

Theorem 3.1 (Variational Bound under Truncated-Normal Dequantization).

Suppose we use the truncated-normal dequantization in Eqn. (10), then the discrete model distribution has the following variational bound:

$$\begin{aligned}\log P_{0}(\bm{x}_{0})&\geq\mathbb{E}_{q(\hat{\bm{\epsilon}})}\left[\log p_{\epsilon}(\hat{\bm{x}}_{\epsilon})\right]+\frac{d}{2}\left(1+\log(2\pi\sigma_{\epsilon}^{2})\right)\\&\quad+d\log Z-d\frac{\tau}{\sqrt{2\pi}Z}\exp\left(-\frac{1}{2}\tau^{2}\right)\end{aligned}$$

where

$$\begin{aligned}\tau&=\frac{\alpha_{\epsilon}}{256\sigma_{\epsilon}},\quad Z=\mathrm{erf}\left(\frac{\tau}{\sqrt{2}}\right),\\\hat{\bm{x}}_{\epsilon}&=\alpha_{\epsilon}\bm{x}_{0}+\sigma_{\epsilon}\hat{\bm{\epsilon}},\quad\hat{\bm{\epsilon}}\sim\mathcal{TN}\left(\hat{\bm{\epsilon}}\,\middle|\,\bm{0},\bm{I},-\tau,\tau\right).\end{aligned}$$
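The terms after the expectation in Theorem 3.1 do not depend on $\bm{x}_0$. They equal $d$ times the entropy of a standard normal truncated to $[-\tau,\tau]$, which is $\frac{1}{2}(1+\log 2\pi)+\log Z-\frac{\tau}{\sqrt{2\pi}Z}e^{-\tau^2/2}$, plus $d\log\sigma_\epsilon$. The Python sketch below (function names are illustrative) computes them, so that only $\mathbb{E}_{q(\hat{\bm{\epsilon}})}[\log p_\epsilon(\hat{\bm{x}}_\epsilon)]$ requires the model.

```python
import math


def std_tn_entropy(tau):
    """Entropy of a 1-D standard normal truncated to [-tau, tau]."""
    Z = math.erf(tau / math.sqrt(2.0))
    return (0.5 * (1.0 + math.log(2.0 * math.pi)) + math.log(Z)
            - tau * math.exp(-0.5 * tau ** 2) / (math.sqrt(2.0 * math.pi) * Z))


def tn_bound_constant(d, alpha_eps, sigma_eps):
    """Data-independent terms of Theorem 3.1:
    d/2 (1 + log(2 pi sigma^2)) + d log Z - d tau exp(-tau^2 / 2) / (sqrt(2 pi) Z)."""
    tau = alpha_eps / (256.0 * sigma_eps)
    Z = math.erf(tau / math.sqrt(2.0))
    return (0.5 * d * (1.0 + math.log(2.0 * math.pi * sigma_eps ** 2))
            + d * math.log(Z)
            - d * tau * math.exp(-0.5 * tau ** 2) / (math.sqrt(2.0 * math.pi) * Z))
```

The truncation lowers the per-dimension entropy slightly below the untruncated Gaussian value $\frac{1}{2}(1+\log 2\pi)$; at $\tau=3$ it is about $1.4029$ nats.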

Besides, by applying Jensen's inequality to $K$ i.i.d. samples as in Burda et al. (2015), we obtain the following importance-weighted likelihood estimator. As $K$ increases, the estimator gives a tighter bound, which enables more precise likelihood estimation.

Corollary 3.2 (Importance Weighted Variational Bound under Truncated-Normal Dequantization).

Suppose we use the truncated-normal dequantization in Eqn. (10), then the discrete model distribution has the following importance-weighted variational bound:

$$\log P_{0}(\bm{x}_{0})\geq\mathbb{E}_{\prod_{i=1}^{K}q(\hat{\bm{\epsilon}}^{(i)})}\left[\log\left(\frac{1}{K}\sum_{i=1}^{K}\frac{p_{\epsilon}(\hat{\bm{x}}_{\epsilon}^{(i)})}{q(\hat{\bm{\epsilon}}^{(i)})}\right)\right]+d\log\sigma_{\epsilon}$$

where

$$\begin{aligned}\hat{\bm{x}}_{\epsilon}^{(i)}&=\alpha_{\epsilon}\bm{x}_{0}+\sigma_{\epsilon}\hat{\bm{\epsilon}}^{(i)},\quad\hat{\bm{\epsilon}}^{(i)}\sim\mathcal{TN}\left(\hat{\bm{\epsilon}}^{(i)}\,\middle|\,\bm{0},\bm{I},-\tau,\tau\right),\\q(\hat{\bm{\epsilon}})&=\frac{1}{(2\pi Z^{2})^{\frac{d}{2}}}\exp\left(-\frac{1}{2}\|\hat{\bm{\epsilon}}\|_{2}^{2}\right),\quad Z=\mathrm{erf}\left(\frac{\tau}{\sqrt{2}}\right).\end{aligned}$$
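A single-draw estimate of the importance-weighted bound in Corollary 3.2 can be sketched as below. It uses a numerically stable log-mean-exp, the closed-form $\log q(\hat{\bm{\epsilon}})$ above, and inverse-CDF sampling for $\hat{\bm{\epsilon}}^{(i)}$. As before, `log_p` is a hypothetical callable for $\log p_\epsilon$ that would in practice come from an ODE solver; this is an illustrative sketch, not the authors' code.

```python
import math
import random
from statistics import NormalDist


def log_q_tn(eps_hat, tau):
    """log q(eps_hat) for a standard normal truncated to [-tau, tau]^d."""
    d = len(eps_hat)
    Z = math.erf(tau / math.sqrt(2.0))
    return -0.5 * d * math.log(2.0 * math.pi * Z ** 2) - 0.5 * sum(e * e for e in eps_hat)


def log_mean_exp(values):
    """Numerically stable log((1/K) * sum_i exp(values[i]))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def iw_bound(x0, log_p, alpha_eps, sigma_eps, K=8, rng=None):
    """One-draw estimate of the K-sample importance-weighted bound.
    `log_p` is a hypothetical callable returning log p_eps at a point."""
    rng = rng or random.Random(0)
    std = NormalDist()
    tau = alpha_eps / (256.0 * sigma_eps)
    lo, hi = std.cdf(-tau), std.cdf(tau)
    log_w = []
    for _ in range(K):
        # eps_hat ~ TN(0, I, -tau, tau) via inverse-CDF sampling
        eps_hat = [std.inv_cdf(rng.uniform(lo, hi)) for _ in x0]
        x_hat = [alpha_eps * a + sigma_eps * e for a, e in zip(x0, eps_hat)]
        log_w.append(log_p(x_hat) - log_q_tn(eps_hat, tau))
    return log_mean_exp(log_w) + len(x0) * math.log(sigma_eps)
```

If `log_p` happens to make all importance weights equal (for example, $p_\epsilon$ is exactly the pushforward of $q$ under $\hat{\bm{\epsilon}}\mapsto\alpha_\epsilon\bm{x}_0+\sigma_\epsilon\hat{\bm{\epsilon}}$), the estimate is $0$ for every $K$, which makes a convenient check on the $d\log\sigma_\epsilon$ term.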
Remark 3.3.

Another way to bridge the discrete-continuous gap is the variational perspective. We can view the process from the discrete $\bm{x}_{0}$ to the continuous $\bm{x}_{\epsilon}$ as a variational autoencoder, where the prior $p_{\epsilon}(\bm{x}_{\epsilon})$ is modeled by the diffusion ODE. The dequantization and variational perspectives of diffusion ODEs are closely related both theoretically and empirically, and we discuss them in detail in Appendix A.

4 Practical Techniques for Improving the Likelihood of Diffusion ODEs

In this section, we propose several practical techniques for improving the likelihood of diffusion ODEs, including parameterization, a high-order training objective, and variance reduction by importance sampling. For simplicity, we denote $\dot{f}_{x}=\frac{\mathrm{d}f(x)}{\mathrm{d}x}$ for any scalar function $f(x)$.

4.1 Velocity Parameterization

While the score matching objective $\mathcal{J}_{\text{SM}}(\theta)$ depends only on the noise schedule, the training process is affected by many other factors, such as the network parameterization (Song et al., 2021c; Karras et al., 2022). For example, the noise predictor $\bm{\epsilon}_{\theta}(\bm{x}_{t},t)$ is widely used in place of the score predictor $\bm{s}_{\theta}(\bm{x}_{t},t)$, since the noise $\bm{\epsilon}\sim\mathcal{N}(\mathbf{0},\bm{I})$ has unit variance and is easier to fit, whereas $\bm{s}_{\theta}(\bm{x}_{t},t)=-\bm{\epsilon}_{\theta}(\bm{x}_{t},t)/\sigma_{t}$ is pathological and explodes near $t=0$ (Song et al., 2021c).

In this work, we consider another network parameterization that directly predicts the drift of the diffusion ODE. The parameterized model is defined by

$$\frac{\mathrm{d}\bm{x}_{t}}{\mathrm{d}t}=\bm{v}_{\theta}(\bm{x}_{t},t)\coloneqq f(t)\bm{x}_{t}-\frac{1}{2}g^{2}(t)\bm{s}_{\theta}(\bm{x}_{t},t). \tag{11}$$

By rewriting the (first-order) score matching objective in Eqn. (5), $\mathcal{J}_{\text{SM}}(\theta)$ is equivalent to:

$$\mathcal{J}_{\text{FM}}(\theta)\coloneqq\int_{0}^{T}\frac{2}{g^{2}(t)}\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}\left[\|\bm{v}_{\theta}(\bm{x}_{t},t)-\bm{v}\|_{2}^{2}\right]\mathrm{d}t, \tag{12}$$

where $\bm{v}=\dot{\alpha}_{t}\bm{x}_{0}+\dot{\sigma}_{t}\bm{\epsilon}$ is the velocity to predict. Given unlimited model capacity, the optimal $\bm{v}^{*}$ is

$$\bm{v}^{*}(\bm{x}_{t},t)=f(t)\bm{x}_{t}-\frac{1}{2}g^{2}(t)\nabla_{\bm{x}}\log q_{t}(\bm{x}_{t}), \tag{13}$$

which is the drift of the probability flow ODE in Eqn. (2).
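To make the link between the noise predictor, Eqn. (11), and Eqn. (12) concrete, here is a small sketch. It assumes a hypothetical variance-preserving schedule $\alpha_t=\cos(\pi t/2)$, $\sigma_t=\sin(\pi t/2)$ and uses the usual relations $f(t)=\dot{\alpha}_t/\alpha_t$ and $g^2(t)=\frac{\mathrm{d}\sigma_t^2}{\mathrm{d}t}-2f(t)\sigma_t^2$ for $\bm{x}_t=\alpha_t\bm{x}_0+\sigma_t\bm{\epsilon}$. Substituting $\bm{s}_\theta=-\bm{\epsilon}_\theta/\sigma_t$ into Eqn. (11) gives $\bm{v}_\theta-\bm{v}=\frac{g^2(t)}{2\sigma_t}(\bm{\epsilon}_\theta-\bm{\epsilon})$, so a perfect noise prediction recovers the velocity target exactly and the integrand of Eqn. (12) equals $\frac{g^2(t)}{2\sigma_t^2}\|\bm{\epsilon}_\theta-\bm{\epsilon}\|_2^2$.

```python
import math


def cosine_schedule(t):
    """Hypothetical VP schedule alpha_t = cos(pi t/2), sigma_t = sin(pi t/2),
    returned together with their time derivatives."""
    a = math.cos(0.5 * math.pi * t)
    s = math.sin(0.5 * math.pi * t)
    return a, s, -0.5 * math.pi * s, 0.5 * math.pi * a


def sde_coeffs(t):
    """f(t) = d log(alpha_t)/dt and g^2(t) = d(sigma_t^2)/dt - 2 f(t) sigma_t^2."""
    a, s, a_dot, s_dot = cosine_schedule(t)
    f = a_dot / a
    return f, 2.0 * s * s_dot - 2.0 * f * s * s


def velocity_from_noise(x_t, eps_pred, t):
    """Eqn. (11) with s_theta = -eps_pred / sigma_t: v = f x_t - 0.5 g^2 s_theta."""
    _, s, _, _ = cosine_schedule(t)
    f, g2 = sde_coeffs(t)
    return [f * x + 0.5 * g2 * e / s for x, e in zip(x_t, eps_pred)]


def fm_integrand(v_pred, x0, eps, t):
    """Per-sample integrand of Eqn. (12): (2 / g^2) ||v_pred - v||^2,
    with target v = alpha_dot * x0 + sigma_dot * eps."""
    _, _, a_dot, s_dot = cosine_schedule(t)
    _, g2 = sde_coeffs(t)
    v = [a_dot * x + s_dot * e for x, e in zip(x0, eps)]
    return 2.0 / g2 * sum((p - q) ** 2 for p, q in zip(v_pred, v))
```

For this schedule $\frac{g^2(t)}{2\sigma_t}=\frac{\pi}{2\cos(\pi t/2)}$, which stays bounded as $t\to0$ even though $\bm{s}_\theta=-\bm{\epsilon}_\theta/\sigma_t$ does not.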

In Appendix D, we give an intuitive explanation of $\mathcal{J}_{\text{FM}}$: the prediction target $\bm{v}$ is the tangent (velocity) of the diffusion path, so we name $\bm{v}_{\theta}$ the velocity parameterization. We also show that it empirically alleviates the imbalance problem in noise prediction.

In addition, we prove the equivalence between different predictors and different matching objectives for general noise schedules in Appendix B. We also show in Appendix E that the flow matching methods (Lipman et al., 2022; Albergo & Vanden-Eijnden, 2022; Liu et al., 2022b) and related techniques for improving the sample quality of diffusion models (Karras et al., 2022; Salimans & Ho, 2022; Ho et al., 2022) can all be reformulated in velocity parameterization. For consistency, we still call $\mathcal{J}_{\text{FM}}$ flow matching. It is an extended version of Lipman et al. (2022) with likelihood weighting and several practical modifications, as detailed in Section 4.3.

4.2 Error-bounded Second-Order Flow Matching

According to Chen et al. (2018a), the ODE likelihood of Eqn. (11) can be evaluated by solving the following differential equation from $\epsilon$ to $T$:

$$\frac{\mathrm{d}\log p_{t}(\bm{x}_{t})}{\mathrm{d}t}=-\mathrm{tr}\left(\nabla_{\bm{x}}\bm{v}_{\theta}(\bm{x}_{t},t)\right). \tag{14}$$
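The instantaneous change-of-variables computation in Eqn. (14) can be sketched as an augmented ODE integrated jointly with $\bm{x}_t$. The sketch below uses forward Euler steps and an exact trace computed by central finite differences (one pair of evaluations per dimension), which is only practical for tiny $d$; high-dimensional models usually replace it with a stochastic trace estimator. The velocity field passed in is a hypothetical stand-in for $\bm{v}_\theta$, and the standard normal $p_T$ in `log_likelihood` is an assumption.

```python
import math


def divergence_fd(v, x, t, h=1e-5):
    """tr(grad_x v(x, t)) via central finite differences along each coordinate."""
    tr = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        tr += (v(xp, t)[i] - v(xm, t)[i]) / (2.0 * h)
    return tr


def integrate_log_density(v, x_start, t_start, t_end, n_steps=1000):
    """Forward-Euler integration of dx/dt = v(x, t) together with
    d log p_t(x_t)/dt = -tr(grad_x v)  (Eqn. 14).
    Returns x at t_end and delta = log p_{t_end}(x_end) - log p_{t_start}(x_start)."""
    x = list(x_start)
    dt = (t_end - t_start) / n_steps
    t, delta = t_start, 0.0
    for _ in range(n_steps):
        div = divergence_fd(v, x, t)
        vx = v(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, vx)]
        delta -= dt * div
        t += dt
    return x, delta


def log_likelihood(v, x_start, t_start, t_end, n_steps=1000):
    """log p_{t_start}(x_start), assuming (hypothetically) a standard normal p_{t_end}."""
    x_end, delta = integrate_log_density(v, x_start, t_start, t_end, n_steps)
    log_prior = -0.5 * sum(xi * xi for xi in x_end) - 0.5 * len(x_end) * math.log(2.0 * math.pi)
    return log_prior - delta
```

For the linear field $\bm{v}(\bm{x},t)=-\bm{x}$ in $d$ dimensions, $\mathrm{tr}(\nabla_{\bm{x}}\bm{v})=-d$, so the accumulated change in $\log p_t$ from $t=0$ to $t=1$ is exactly $d$, and $\bm{x}_1\approx e^{-1}\bm{x}_0$ up to Euler error.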

The objective $\mathcal{J}_{\text{FM}}$ in Eqn. (12) only restricts the distance between $\bm{v}_{\theta}$ and $\bm{v}^{*}$, not the distance between their divergences $\mathrm{tr}(\nabla_{\bm{x}}\bm{v}_{\theta})$ and $\mathrm{tr}(\nabla_{\bm{x}}\bm{v}^{*})$. Yet the precision and smoothness of the trace $\mathrm{tr}(\nabla_{\bm{x}}\bm{v}_{\theta}(\bm{x}_{t},t))$ affect both the likelihood performance and the number of function evaluations for sampling. For simulation-free training of $\mathrm{tr}(\nabla_{\bm{x}}\bm{v}_{\theta}(\bm{x}_{t},t))$, we propose an error-bounded trace of second-order flow matching, in which the second-order error is bounded by the proposed objective and the first-order error.
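The regression target in Eqn. (15) of Theorem 4.1 is cheap to form from one sample of $(\bm{x}_0,\bm{\epsilon})$. The sketch below builds it and checks it on a one-dimensional Gaussian toy problem, where $\bm{x}_0\sim\mathcal{N}(0,s_0^2)$ makes $\bm{v}^*$ linear with a known trace. Taking $\hat{\bm{v}}_1=\bm{v}^*$ (zero first-order error) and averaging the target over the exact posterior of $\bm{x}_0$ given $\bm{x}_t$ should then reproduce $\mathrm{tr}(\nabla_{\bm{x}}\bm{v}^*)$. The cosine schedule and the toy data distribution are assumptions for illustration only.

```python
import math
import random


def cosine_schedule(t):
    """Hypothetical VP schedule alpha_t = cos(pi t/2), sigma_t = sin(pi t/2),
    returned together with their time derivatives."""
    a = math.cos(0.5 * math.pi * t)
    s = math.sin(0.5 * math.pi * t)
    return a, s, -0.5 * math.pi * s, 0.5 * math.pi * a


def trace_target(v_hat, v, sigma, sigma_dot, g2):
    """Regression target for v_2^trace in Eqn. (15): (sigma_dot / sigma) * d - l1,
    with l1 = (2 / g^2) * ||v_hat - v||^2."""
    d = len(v)
    l1 = 2.0 / g2 * sum((p - q) ** 2 for p, q in zip(v_hat, v))
    return sigma_dot / sigma * d - l1


def gaussian_toy_check(x_t, t, s0=0.5, n=100_000, seed=0):
    """1-D toy with x0 ~ N(0, s0^2), so v* is linear and tr(grad v*) is known.
    Averages the target over the exact posterior x0 | x_t using v_hat = v*."""
    rng = random.Random(seed)
    a, s, a_dot, s_dot = cosine_schedule(t)
    f = a_dot / a
    g2 = 2.0 * s * s_dot - 2.0 * f * s * s
    var_t = a * a * s0 * s0 + s * s          # marginal variance of x_t
    slope = f + 0.5 * g2 / var_t             # v*(x) = slope * x, so tr(grad v*) = slope
    post_mean = a * s0 * s0 * x_t / var_t    # posterior x0 | x_t is Gaussian
    post_std = s0 * s / math.sqrt(var_t)
    total = 0.0
    for _ in range(n):
        x0 = rng.gauss(post_mean, post_std)
        eps = (x_t - a * x0) / s             # noise consistent with (x0, x_t)
        v = a_dot * x0 + s_dot * eps
        total += trace_target([slope * x_t], [v], s, s_dot, g2)
    return total / n, slope
```

At $t=0.5$ with $s_0=0.5$ the exact trace is $0.3\pi\approx0.942$, and the Monte Carlo average of the target agrees with it up to sampling error.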

Theorem 4.1.

(Error-Bounded Trace of Second-Order Flow Matching) Suppose we have a first-order velocity estimator $\hat{\bm{v}}_{1}(\bm{x}_{t},t)$. Then we can learn a second-order trace velocity model $\bm{v}_{2}^{\text{trace}}(\cdot,t;\theta):\mathbb{R}^{d}\rightarrow\mathbb{R}$, which minimizes

$$\mathbb{E}_{q_{t}(\bm{x}_{t})}\left[\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)-\mathrm{tr}\left(\nabla_{\bm{x}}\bm{v}^{*}(\bm{x}_{t},t)\right)\right|^{2}\right],$$

by optimizing

$$\theta^{*}=\operatorname*{argmin}_{\theta}\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}\left[\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)-\frac{\dot{\sigma}_{t}}{\sigma_{t}}d+\bm{\ell}_{1}\right|^{2}\right], \tag{15}$$

where

$$\bm{\ell}_1(\bm{\epsilon},\bm{x}_0,t)\coloneqq\frac{2}{g^2(t)}\|\hat{\bm{v}}_1(\bm{x}_t,t)-\bm{v}\|_2^2,$$
$$\bm{x}_t=\alpha_t\bm{x}_0+\sigma_t\bm{\epsilon},\quad\bm{v}=\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon},\quad\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I}).$$

Moreover, denote the first-order flow matching error as $\delta_1(\bm{x}_t,t)\coloneqq\|\hat{\bm{v}}_1(\bm{x}_t,t)-\bm{v}^*(\bm{x}_t,t)\|_2$. Then, for all $\bm{x}_t$ and $\theta$, the estimation error of $\bm{v}_2^{\text{trace}}(\bm{x}_t,t;\theta)$ can be bounded by:

$$\begin{aligned}&\left|\bm{v}_2^{\text{trace}}(\bm{x}_t,t;\theta)-\mathrm{tr}\left(\nabla_{\bm{x}}\bm{v}^*(\bm{x}_t,t)\right)\right|\\ \leq\ &\left|\bm{v}_2^{\text{trace}}(\bm{x}_t,t;\theta)-\bm{v}_2^{\text{trace}}(\bm{x}_t,t;\theta^*)\right|+\frac{2}{g^2(t)}\delta_1^2(\bm{x}_t,t).\end{aligned}$$

The proof is provided in Appendix F. In practice, we choose $\bm{v}_2^{\text{trace}}(\bm{x}_t,t;\theta)=\mathrm{tr}(\nabla_{\bm{x}}\bm{v}_\theta(\bm{x}_t,t))$, so that the second-order objective acts directly on the first-order network $\bm{v}_\theta$ (self-regularization). For scalability, we use Hutchinson's trace estimator (Hutchinson, 1990) to obtain an unbiased estimate of the trace, and use forward-mode automatic differentiation to compute the required Jacobian-vector products (Lu et al., 2022a).
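The single-sample integrand of Eqn. (15) with $\bm{v}_2^{\text{trace}}=\mathrm{tr}(\nabla_{\bm{x}}\bm{v}_\theta)$ and a Hutchinson trace estimate can be sketched as follows. This is a minimal, dependency-free illustration, not the authors' JAX code: the schedule values at a fixed $t$ are passed in as plain numbers, $g^2(t)$ is computed from the standard identity $g^2(t)=\mathrm{d}\sigma_t^2/\mathrm{d}t-2(\mathrm{d}\log\alpha_t/\mathrm{d}t)\sigma_t^2$, and a central finite difference stands in for forward-mode autodiff (in JAX one would use `jax.jvp`). All function names are hypothetical. Note that squaring a single-probe trace estimate adds the probe variance to the loss, so it is not an unbiased estimate of the squared residual.

```python
import random


def g2_from_schedule(alpha, sigma, dalpha, dsigma):
    # Diffusion coefficient of the forward SDE: g^2 = d(sigma^2)/dt - 2 (d log alpha / dt) sigma^2.
    return 2.0 * sigma * dsigma - 2.0 * (dalpha / alpha) * sigma ** 2


def jvp_fd(f, x, z, h=1e-4):
    # Central finite-difference Jacobian-vector product J_f(x) z.
    # Stand-in for forward-mode autodiff in this dependency-free sketch.
    fp = f([a + h * b for a, b in zip(x, z)])
    fm = f([a - h * b for a, b in zip(x, z)])
    return [(p - m) / (2.0 * h) for p, m in zip(fp, fm)]


def hutchinson_trace(f, x, rng: random.Random, num_probes=1):
    # Hutchinson: tr(J) = E_z[z^T J z] for Rademacher probes z (unbiased given an exact JVP).
    total = 0.0
    for _ in range(num_probes):
        z = [rng.choice((-1.0, 1.0)) for _ in x]
        total += sum(a * b for a, b in zip(z, jvp_fd(f, x, z)))
    return total / num_probes


def second_order_loss(v_theta, v1_hat, x0, eps, alpha, sigma, dalpha, dsigma, rng):
    # Single-sample integrand of Eqn. (15): (tr(grad v_theta)(x_t) - dsigma/sigma * d + l1)^2.
    d = len(x0)
    x_t = [alpha * a + sigma * e for a, e in zip(x0, eps)]
    v = [dalpha * a + dsigma * e for a, e in zip(x0, eps)]
    # l1 uses the frozen first-order predictor v1_hat (no gradient flows through it).
    sq_err = sum((p - q) ** 2 for p, q in zip(v1_hat(x_t), v))
    l1 = 2.0 / g2_from_schedule(alpha, sigma, dalpha, dsigma) * sq_err
    trace_est = hutchinson_trace(v_theta, x_t, rng)
    return (trace_est - dsigma / sigma * d + l1) ** 2
```

For a diagonal Jacobian every Rademacher probe returns the exact trace, since $z_i^2=1$; off-diagonal entries only add zero-mean noise that averages out over probes.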

4.3 Timing by Log-SNR and Normalizing Velocity

In practice, we make two modifications to improve performance. First, we use the negative log-SNR $\gamma_t$ to time the diffusion process. We still parameterize a velocity, $\bm{v}_\theta(\bm{x}_\gamma,\gamma)$, but it now predicts the drift of the $\gamma$-timed diffusion ODE, i.e., $\frac{\mathrm{d}\bm{x}_\gamma}{\mathrm{d}\gamma}=\bm{v}_\theta(\bm{x}_\gamma,\gamma)$, so the corresponding $t$-timed predictor is $\bm{v}_\theta(\bm{x}_t,t)=\bm{v}_\theta(\bm{x}_\gamma,\gamma)\frac{\mathrm{d}\gamma}{\mathrm{d}t}$.
Second, the velocity of the diffusion path $\bm{v}=\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon}$ may have very different scales at different $t$, so we propose to predict the normalized velocity $\tilde{\bm{v}}=\bm{v}/\sqrt{\dot{\alpha}_t^2+\dot{\sigma}_t^2}$ with the parameterized network $\tilde{\bm{v}}_\theta(\bm{x}_t,t)=\bm{v}_\theta(\bm{x}_t,t)/\sqrt{\dot{\alpha}_t^2+\dot{\sigma}_t^2}$. Since $\bm{v}_\theta(\bm{x}_t,t)$, $\dot{\alpha}_t$ and $\dot{\sigma}_t$ all pick up the same positive factor $\mathrm{d}\gamma/\mathrm{d}t$ under the change of time, this is equal to $\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)=\bm{v}_\theta(\bm{x}_\gamma,\gamma)/\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}$. The objective in Eqn. (12) then reduces to

$$\mathcal{J}_{\text{FM}}(\theta)=\int_{\gamma_0}^{\gamma_T}2\frac{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}{\sigma_\gamma^2}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\|\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\|_2^2\,\mathrm{d}\gamma.$$
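To make the normalization concrete, the sketch below computes the integrand of $\mathcal{J}_{\text{FM}}$ for the VP schedule in $\gamma$-time, assuming $\alpha_\gamma^2=\mathrm{sigmoid}(-\gamma)$ and $\sigma_\gamma^2=\mathrm{sigmoid}(\gamma)$ (so that $\alpha_\gamma^2+\sigma_\gamma^2=1$ and $-\log(\alpha_\gamma^2/\sigma_\gamma^2)=\gamma$). Differentiating gives $\dot{\alpha}_\gamma=-\alpha_\gamma\sigma_\gamma^2/2$ and $\dot{\sigma}_\gamma=\sigma_\gamma\alpha_\gamma^2/2$, hence $\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}=\alpha_\gamma\sigma_\gamma/2$ and the weight $2(\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2)/\sigma_\gamma^2=\alpha_\gamma^2/2$. This is an illustrative Python sketch; `vp_schedule` and `fm_integrand` are hypothetical names, and `v_tilde_pred` stands in for the network $\tilde{\bm{v}}_\theta$.

```python
import math


def vp_schedule(gamma):
    # VP schedule timed by negative log-SNR: alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma).
    alpha = math.sqrt(1.0 / (1.0 + math.exp(gamma)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(-gamma)))
    dalpha = -0.5 * alpha * sigma ** 2  # d alpha / d gamma
    dsigma = 0.5 * sigma * alpha ** 2   # d sigma / d gamma
    return alpha, sigma, dalpha, dsigma


def fm_integrand(v_tilde_pred, x0, eps, gamma):
    # L_theta(x0, eps, gamma) = 2 (dalpha^2 + dsigma^2) / sigma^2 * ||v_tilde_pred - v_tilde||^2
    alpha, sigma, dalpha, dsigma = vp_schedule(gamma)
    scale = math.sqrt(dalpha ** 2 + dsigma ** 2)  # equals alpha * sigma / 2 for VP
    x_g = [alpha * a + sigma * e for a, e in zip(x0, eps)]
    # Regression target: the normalized path velocity.
    v_tilde = [(dalpha * a + dsigma * e) / scale for a, e in zip(x0, eps)]
    pred = v_tilde_pred(x_g, gamma)
    weight = 2.0 * scale ** 2 / sigma ** 2        # equals alpha^2 / 2 for VP
    return weight * sum((p - q) ** 2 for p, q in zip(pred, v_tilde))
```

With a zero predictor the integrand equals $2\|\bm{v}\|_2^2/\sigma_\gamma^2$, i.e., the normalization only moves the scale from the target into the weight.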

The corresponding second-order objective is

$$\begin{aligned}\mathcal{J}_{\text{FM},\mathrm{tr}}(\theta)=&\int_{\gamma_0}^{\gamma_T}2\frac{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}{\sigma_\gamma^2}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\Bigg(\sigma_\gamma\mathrm{tr}(\nabla\tilde{\bm{v}}_\theta)-\frac{\dot{\sigma}_\gamma}{\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}}d\\&+\frac{2\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}}{\sigma_\gamma}\|\tilde{\bm{v}}^{(s)}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\|_2^2\Bigg)^2\mathrm{d}\gamma,\end{aligned}\tag{16}$$

where $\tilde{\bm{v}}^{(s)}_\theta$ is the stop-gradient version of $\tilde{\bm{v}}_\theta$, since the parameterized first-order velocity predictor serves only as the estimator $\hat{\bm{v}}_1$ in this term. Our final formulation of the parameterized diffusion ODE is

$$\frac{\mathrm{d}\bm{x}_\gamma}{\mathrm{d}\gamma}=\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}\,\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma). \tag{17}$$
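A minimal way to see Eqn. (17) in action is to integrate it with the exact normalized velocity of a toy problem. The sketch assumes the VP schedule ($\alpha_\gamma^2=\mathrm{sigmoid}(-\gamma)$, $\sigma_\gamma^2=\mathrm{sigmoid}(\gamma)$) and a hypothetical point-mass data distribution at $\bm{c}$, for which the true trajectory is $\bm{x}_\gamma=\alpha_\gamma\bm{c}+\sigma_\gamma\bm{\epsilon}$; that exact velocity stands in for the trained $\tilde{\bm{v}}_\theta$. Fixed-step Euler is used only for illustration (the results in Table 1 use an adaptive-step solver), and all names are hypothetical.

```python
import math


def vp_schedule(gamma):
    # alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma), plus derivatives w.r.t. gamma.
    alpha = math.sqrt(1.0 / (1.0 + math.exp(gamma)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(-gamma)))
    return alpha, sigma, -0.5 * alpha * sigma ** 2, 0.5 * sigma * alpha ** 2


def point_mass_v_tilde(c):
    # Exact normalized velocity for data concentrated at c:
    # x = alpha c + sigma eps, so v* = dalpha c + dsigma (x - alpha c) / sigma.
    def v_tilde(x, gamma):
        alpha, sigma, dalpha, dsigma = vp_schedule(gamma)
        scale = math.sqrt(dalpha ** 2 + dsigma ** 2)
        return [(dalpha * ci + dsigma * (xi - alpha * ci) / sigma) / scale
                for xi, ci in zip(x, c)]
    return v_tilde


def euler_solve(v_tilde_fn, x_start, gamma_start, gamma_end, num_steps):
    # Explicit Euler for dx/dgamma = sqrt(dalpha^2 + dsigma^2) * v_tilde(x, gamma).
    h = (gamma_end - gamma_start) / num_steps  # negative when integrating toward the data
    x = list(x_start)
    for k in range(num_steps):
        gamma = gamma_start + k * h
        _, _, dalpha, dsigma = vp_schedule(gamma)
        scale = math.sqrt(dalpha ** 2 + dsigma ** 2)
        x = [xi + h * scale * vi for xi, vi in zip(x, v_tilde_fn(x, gamma))]
    return x
```

Starting from $\alpha_{\gamma_T}\bm{c}+\sigma_{\gamma_T}\bm{\epsilon}$ and integrating down to $\gamma_0$ should land close to $\alpha_{\gamma_0}\bm{c}+\sigma_{\gamma_0}\bm{\epsilon}$, with an $O(h)$ discretization error.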

4.4 Variance Reduction with Importance Sampling

Flow matching is performed over all $\gamma\in[\gamma_0,\gamma_T]$ through an integral. Since evaluating this integral exactly is time-consuming, in practice Monte-Carlo methods are used to estimate the objective unbiasedly by sampling $\gamma$ uniformly. The variance of this Monte-Carlo estimator then affects the optimization process, so a continuous importance distribution $p(\gamma)$ can be introduced for variance reduction. Denote $\mathcal{L}_\theta(\bm{x}_0,\bm{\epsilon},\gamma)=2\frac{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}{\sigma_\gamma^2}\|\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\|_2^2$; then

$$\mathcal{J}_{\text{FM}}(\theta)=\mathbb{E}_{\gamma\sim p(\gamma)}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\frac{\mathcal{L}_\theta(\bm{x}_0,\bm{\epsilon},\gamma)}{p(\gamma)}\right]. \tag{18}$$
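The importance-weighted estimator in Eqn. (18) can be illustrated on a one-dimensional toy problem. The sketch below is a generic Monte-Carlo estimator, not the paper's training loop: the integrand `toy_loss` $=3\gamma^2$ on $[0,1]$ (true integral 1) and the two proposals are hypothetical choices made only to show that both proposals are unbiased while a proposal shaped like the integrand lowers the variance.

```python
import math
import random


def is_estimate(loss_fn, sample_fn, pdf_fn, num_samples, rng):
    # Monte-Carlo estimate of the integral of loss_fn as in Eqn. (18):
    # average of loss(gamma) / p(gamma) with gamma ~ p. Returns (mean, sample variance).
    vals = []
    for _ in range(num_samples):
        g = sample_fn(rng)
        vals.append(loss_fn(g) / pdf_fn(g))
    mean = sum(vals) / num_samples
    var = sum((v - mean) ** 2 for v in vals) / (num_samples - 1)
    return mean, var


def toy_loss(g):
    # Hypothetical integrand on [gamma_0, gamma_T] = [0, 1]; its integral is exactly 1.
    return 3.0 * g * g


def uniform_sampler(rng):
    return rng.random()


def uniform_pdf(g):
    return 1.0


def linear_sampler(rng):
    # Inverse CDF of p(g) = 2g on (0, 1]; using 1 - random() avoids g = 0.
    return math.sqrt(1.0 - rng.random())


def linear_pdf(g):
    return 2.0 * g


rng = random.Random(0)
mean_u, var_u = is_estimate(toy_loss, uniform_sampler, uniform_pdf, 20000, rng)
mean_p, var_p = is_estimate(toy_loss, linear_sampler, linear_pdf, 20000, rng)
```

Analytically, the per-sample variance is 0.8 under the uniform proposal and 0.125 under $p(\gamma)=2\gamma$, while both means equal the true integral.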

We propose two types of importance sampling (IS) and compare them empirically in terms of convergence speed.

Designed IS

Intuitively, we can choose $p(\gamma)\propto\frac{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}{\sigma_\gamma^2}$. With this choice, the coefficient of $\|\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\|_2^2$ in $\mathcal{L}_\theta/p$ becomes a time-invariant constant, so the velocity matching error is neither amplified nor shrunk at any $\gamma$. This is similar to the IS in Song et al. (2021b), which cancels the $g^2(t)/\sigma_t^2$ weighting in front of the noise matching error $\|\bm{\epsilon}_\theta(\bm{x}_t,t)-\bm{\epsilon}\|_2^2$; under our parameterization, that IS corresponds to sampling $\gamma$ uniformly.
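The correspondence between the IS of Song et al. (2021b) and uniform $\gamma$ can be checked in one line, assuming the standard linear-Gaussian forward SDE with $f(t)=\mathrm{d}\log\alpha_t/\mathrm{d}t$, $g^2(t)=\mathrm{d}\sigma_t^2/\mathrm{d}t-2f(t)\sigma_t^2$, and $\gamma_t=\log(\sigma_t^2/\alpha_t^2)$:

```latex
g^2(t) = \frac{\mathrm{d}\sigma_t^2}{\mathrm{d}t} - 2\frac{\mathrm{d}\log\alpha_t}{\mathrm{d}t}\sigma_t^2
       = \sigma_t^2\,\frac{\mathrm{d}}{\mathrm{d}t}\log\frac{\sigma_t^2}{\alpha_t^2}
       = \sigma_t^2\,\frac{\mathrm{d}\gamma_t}{\mathrm{d}t}
\quad\Longrightarrow\quad
p(t) \propto \frac{g^2(t)}{\sigma_t^2} = \frac{\mathrm{d}\gamma_t}{\mathrm{d}t}
\iff \gamma_t \sim \mathcal{U}[\gamma_0,\gamma_T].
```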

For the noise schedules used in this paper, closed-form sampling procedures can be obtained by inverse transform sampling (see Appendix C).
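As an illustration of inverse transform sampling for the designed IS, consider the VP schedule with $\alpha_\gamma^2=\mathrm{sigmoid}(-\gamma)$. Using $\dot{\alpha}_\gamma=-\alpha_\gamma\sigma_\gamma^2/2$ and $\dot{\sigma}_\gamma=\sigma_\gamma\alpha_\gamma^2/2$, the weight $2(\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2)/\sigma_\gamma^2$ simplifies to $\mathrm{sigmoid}(-\gamma)/2$, so the designed IS is $p(\gamma)\propto\mathrm{sigmoid}(-\gamma)$, whose CDF is $(\mathrm{softplus}(-\gamma_0)-\mathrm{softplus}(-\gamma))/Z$ and can be inverted in closed form. This derivation covers VP only and is not necessarily the exact form in Appendix C; the endpoints in the usage below are illustrative assumptions, and the function names are hypothetical.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def vp_designed_is_sample(u, gamma0, gamma_T):
    # Inverse-CDF sample from p(gamma) ∝ sigmoid(-gamma) on [gamma0, gamma_T], for u in [0, 1].
    c0, cT = softplus(-gamma0), softplus(-gamma_T)
    c = c0 - u * (c0 - cT)           # target value of softplus(-gamma)
    return -math.log(math.expm1(c))  # invert c = log(1 + exp(-gamma))


def vp_designed_is_pdf(gamma, gamma0, gamma_T):
    z = softplus(-gamma0) - softplus(-gamma_T)  # normalizer: integral of sigmoid(-gamma)
    return 1.0 / (1.0 + math.exp(gamma)) / z
```

Under this proposal the ratio of the weight $\mathrm{sigmoid}(-\gamma)/2$ to $p(\gamma)$ is the constant $Z/2$, which is exactly the time-invariant coefficient described above. Usage: `vp_designed_is_sample(random.random(), -13.3, 5.0)` draws one $\gamma$.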

Learned IS

The variance of the Monte-Carlo estimator depends on the learned network $\tilde{\bm{v}}_\theta$. To minimize it, we can parameterize the IS with another network and treat the variance as an objective. Learning $p(\gamma)$ is equivalent to learning a monotone mapping $\gamma(t):[0,1]\rightarrow[\gamma_0,\gamma_T]$, namely the inverse cumulative distribution function of $p(\gamma)$. We can then sample $t$ uniformly and regard the IS as a change of variables from $\gamma$ to $t$:

$$\mathcal{J}_{\text{FM}}(\theta)=\mathbb{E}_{t\sim\mathcal{U}(0,1)}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\gamma'(t)\mathcal{L}_\theta(\bm{x}_0,\bm{\epsilon},\gamma(t))\right]. \tag{19}$$

Suppose we parameterize $\gamma(t)$ with $\eta$ and denote $\mathcal{L}_{\theta,\eta}(\bm{x}_0,\bm{\epsilon},t)=\gamma_\eta'(t)\mathcal{L}_\theta(\bm{x}_0,\bm{\epsilon},\gamma_\eta(t))$, which is a Monte-Carlo estimator of $\mathcal{J}_{\text{FM}}(\theta)$. Its variance is $\mathrm{Var}_{t,\bm{\epsilon},\bm{x}_0}[\mathcal{L}_{\theta,\eta}(\bm{x}_0,\bm{\epsilon},t)]=\mathbb{E}_{t,\bm{\epsilon},\bm{x}_0}[\mathcal{L}_{\theta,\eta}^2(\bm{x}_0,\bm{\epsilon},t)]-\mathcal{J}_{\text{FM}}^2(\theta)$, and $\mathcal{J}_{\text{FM}}(\theta)$ does not depend on $\gamma_\eta(t)$, so we can reduce the variance by minimizing $\mathbb{E}_{t,\bm{\epsilon},\bm{x}_0}[\mathcal{L}_{\theta,\eta}^2(\bm{x}_0,\bm{\epsilon},t)]$ with respect to $\eta$.
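The learned-IS objective can be illustrated with a hypothetical one-parameter family of monotone maps, $\gamma_\eta(t)=t^\eta$ on $[0,1]$, and a hypothetical toy loss $3\gamma^2$. This is not the paper's IS network; it only shows that the mean of $\gamma_\eta'(t)\mathcal{L}(\gamma_\eta(t))$ stays fixed across $\eta$ while its second moment, and hence its variance, changes.

```python
import random


def estimator_moments(eta, loss_fn, ts):
    # Hypothetical IS family gamma_eta(t) = t**eta, with gamma_eta'(t) = eta * t**(eta - 1).
    # Returns the sample mean and variance of gamma_eta'(t) * loss(gamma_eta(t)) over uniform t.
    vals = [eta * t ** (eta - 1.0) * loss_fn(t ** eta) for t in ts]
    mean = sum(vals) / len(vals)
    second_moment = sum(v * v for v in vals) / len(vals)  # the quantity minimized over eta
    return mean, second_moment - mean ** 2


rng = random.Random(0)
ts = [1.0 - rng.random() for _ in range(20000)]  # uniform samples on (0, 1]
results = [estimator_moments(eta, lambda g: 3.0 * g * g, ts) for eta in (1.0 / 3.0, 0.5, 1.0, 2.0)]
```

For this toy loss the mean is 1 for every $\eta$, while the variance is $9\eta^2/(6\eta-1)-1$ for $\eta>1/6$; it vanishes at $\eta=1/3$, where $\gamma_\eta'(t)\cdot3\gamma_\eta(t)^2\equiv1$.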

Although this approach seeks the optimal IS, it adds overhead: it introduces an IS network and requires either complex gradient operations or additional training steps. We therefore use it only as a reference for testing the optimality of our designed IS. Specifically, we simplify the variance reduction method of Kingma et al. (2021) into an adaptive IS algorithm, detailed in Appendix H. Empirically, we find that designed IS is preferable, since it is training-free and achieves a convergence speed similar to that of learned IS.

5 Related Work

Diffusion models, also known as score-based generative models (SGMs), have achieved state-of-the-art sample quality and likelihood (Dhariwal & Nichol, 2021; Karras et al., 2022; Kingma et al., 2021) among deep generative models, yielding extensive downstream applications such as speech and singing synthesis (Chen et al., 2021; Liu et al., 2022a), conditional image generation (Ramesh et al., 2022; Rombach et al., 2022), guided image editing (Meng et al., 2022; Nichol et al., 2022), unpaired image-to-image translation (Zhao et al., 2022) and inverse problem solving (Chung et al., 2022; Kawar et al., 2022).

Diffusion ODEs are special formulations of neural ODEs and can be viewed as continuous normalizing flows (Chen et al., 2018a). Training methods for diffusion ODEs can be categorized into simulation-based and simulation-free ones. The former use the exact likelihood evaluation formula of ODEs (Chen et al., 2018a), which leads to a maximum likelihood training procedure (Grathwohl et al., 2019). However, this involves expensive ODE simulations for forward and backward propagation and may result in unnecessarily complex dynamics (Finlay et al., 2020), since it only cares about the model distribution at $t=0$. The latter train neural ODEs by matching their trajectories to a predefined path, such as the diffusion process. This approach was proposed in Song et al. (2021c) and extended in Lu et al. (2022a); Lipman et al. (2022); Albergo & Vanden-Eijnden (2022); Liu et al. (2022b). Our velocity parameterization extends that of Lipman et al. (2022) with practical modifications, and we point out that the paths used in Lipman et al. (2022); Albergo & Vanden-Eijnden (2022); Liu et al. (2022b) are special cases of noise schedules. For maximum likelihood training, we also draw inspiration from Lu et al. (2022a); in addition, we apply likelihood weighting and propose to finetune the model with high-order flow matching.

Variance reduction techniques are commonly used for training diffusion models. Nichol & Dhariwal (2021) propose an importance sampling (IS) scheme for discrete-time diffusion models that maintains the historical losses at each time step and builds the proposal distribution from them. Song et al. (2021b) design an IS that cancels out the weighting in front of the noise matching loss. Kingma et al. (2021) propose a variance reduction method that is equivalent to learning a parameterized IS. We simplify their procedure into an adaptive IS scheme for ablation. By empirically comparing different IS methods, we find a designed, analytical IS distribution that achieves a good performance-efficiency trade-off.

6 Experiments

In this section, we present our training procedure and experimental settings, followed by ablation studies demonstrating how our techniques improve the likelihood of diffusion ODEs.

We implement our methods based on the open-source codebase of Kingma et al. (2021), which is implemented with JAX (Bradbury et al., 2018), and use similar network and hyperparameter settings. We first train the model by optimizing our first-order flow matching objective $\min_\theta\mathcal{J}_{\text{FM}}(\theta)$ for sufficiently many iterations that the first-order velocity prediction has small error. Then, we finetune the pretrained first-order model using a mixture of the first-order and second-order flow matching objectives, $\min_\theta\mathcal{J}_{\text{FM}}(\theta)+\lambda\mathcal{J}_{\text{FM},\mathrm{tr}}(\theta)$. Finetuning converges in far fewer iterations than pretraining. Finally, we evaluate the likelihood on the test set using the variational bound under our proposed truncated-normal dequantization. The detailed training configurations are provided in Appendix I.

Our training and evaluation procedures apply to any noise schedule $\alpha_\gamma,\sigma_\gamma$. We choose two special noise schedules:

Variance Preserving (VP)

$\alpha_\gamma^2+\sigma_\gamma^2=1$. This schedule is widely used in diffusion models and yields a process with a fixed variance of one when the initial distribution has unit variance.

Straight Path (SP)

$\alpha_\gamma+\sigma_\gamma=1$. This schedule is used in Lipman et al. (2022); Albergo & Vanden-Eijnden (2022); Liu et al. (2022b), where it is called the OT path and is claimed to lead to better dynamics because the pairwise diffusion paths are straight lines. We simply regard it as a special kind of noise schedule.
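Because $\gamma$ is the negative log-SNR, each constraint together with $\alpha_\gamma^2/\sigma_\gamma^2=e^{-\gamma}$ pins down $\alpha_\gamma$ and $\sigma_\gamma$. A short sketch of the resulting closed forms (our own derivation; function names are hypothetical): VP gives $\alpha_\gamma^2=\mathrm{sigmoid}(-\gamma)$, $\sigma_\gamma^2=\mathrm{sigmoid}(\gamma)$, and SP gives $\alpha_\gamma=\mathrm{sigmoid}(-\gamma/2)$, $\sigma_\gamma=\mathrm{sigmoid}(\gamma/2)$, using the identity $\mathrm{sigmoid}(-x)/\mathrm{sigmoid}(x)=e^{-x}$.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def vp(gamma):
    # VP: alpha^2 + sigma^2 = 1 with alpha^2 / sigma^2 = exp(-gamma).
    return math.sqrt(sigmoid(-gamma)), math.sqrt(sigmoid(gamma))


def sp(gamma):
    # SP: alpha + sigma = 1 with alpha / sigma = exp(-gamma / 2).
    return sigmoid(-gamma / 2.0), sigmoid(gamma / 2.0)


def neg_log_snr(alpha, sigma):
    # gamma = -log(alpha^2 / sigma^2); should recover the input gamma for both schedules.
    return -math.log(alpha ** 2 / sigma ** 2)
```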

Under these two schedules, $\alpha_\gamma,\sigma_\gamma$ are uniquely determined by $\gamma$, so no extra hyperparameters are introduced. Their objectives and designed IS can also be expressed in closed form (see Appendix C for details). We train our i-DODE on CIFAR-10 (Krizhevsky et al., 2009) and ImageNet-32 (Deng et al., 2009)¹, two popular benchmarks for generative modeling and density estimation.

¹ There are two different versions of the ImageNet32 and ImageNet64 datasets. For fair comparisons, we use both versions of ImageNet32: one is downloaded from https://image-net.org/data/downsample/Imagenet32_train.zip, following Lipman et al. (2022), and the other from http://image-net.org/small/train_32x32.tar (old version, no longer available), following Song et al. (2021b) and Kingma et al. (2021). The former applies anti-aliasing and is easier for maximum likelihood training.

6.1 Likelihood and Samples

Table 1: Negative log-likelihood (NLL) in bits/dim (BPD), sample quality (FID score) and number of function evaluations (NFE) on CIFAR-10 and ImageNet 32x32. For fair comparisons, we list NLL results of previous ODEs without variational dequantization or data augmentation (unless specifically stated), and FID/NFE results obtained with an adaptive-step ODE solver. Entries marked "/" are not reported in the original papers or do not apply. †For VDM, which has no ODE formulation, the FID score is obtained by a 1000-step discretization of its SDE; we report its corresponding ODE result in the ablation study. *Corresponding to the old version of the ImageNet-32 dataset.
Model | CIFAR-10 NLL ↓ | CIFAR-10 FID ↓ | CIFAR-10 NFE ↓ | ImageNet-32 NLL ↓ | ImageNet-32 FID ↓ | ImageNet-32 NFE ↓
VDM (Kingma et al., 2021) | 2.65 | 7.60† | 1000 | 3.72* | / | /
VDM (with data augmentation) (Kingma et al., 2021) | 2.49 | / | / | / | / | /
(Previous ODEs)
FFJORD (Grathwohl et al., 2019) | 3.40 | / | / | / | / | /
ScoreSDE (Song et al., 2021c) | 2.99 | 2.92 | / | / | / | /
ScoreFlow (Song et al., 2021b) | 2.90 | 5.40 | / | 3.82* | 10.18* | /
Soft Truncation (Kim et al., 2022) | 3.01 | 3.96 | / | 3.90* | 8.42* | /
Flow Matching (Lipman et al., 2022) | 2.99 | 6.35 | 142 | 3.53 | 5.31 | 122
Stochastic Interp. (Albergo & Vanden-Eijnden, 2022) | 2.99 | 10.27 | / | 3.48 | 8.49 | /
i-DODE (SP) (ours) | 2.56 | 11.20 | 162 | 3.44/3.69* | 10.31 | 138
i-DODE (VP) (ours) | 2.57 | 10.74 | 126 | 3.43/3.70* | 9.09 | 152
i-DODE (VP, with data augmentation) (ours) | 2.42 | 3.76 | 215 | / | / | /
Figure 1: Test loss curve in the pretraining phase, compared to VDM (Kingma et al., 2021). We compute the loss on the test set by the SDE likelihood bound in Kingma et al. (2021).

Table 1 shows our experimental results on the CIFAR-10 and ImageNet-32 datasets. Our models are pretrained with velocity parameterization and designed IS, and then finetuned with second-order flow matching. We report likelihood values using our truncated-normal dequantization with the importance-weighted estimator under $K=20$. To compute FID, we draw samples from the diffusion ODEs with an adaptive-step ODE solver. We also report the NFE of the sampling process, which reflects the smoothness of the dynamics.
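The specific adaptive-step solver is not described in this section. To illustrate why NFE tracks the smoothness of the dynamics, the sketch below uses a deliberately simple embedded Heun-Euler pair and counts vector-field evaluations. This pair is a lower-order stand-in for the Runge-Kutta pairs, such as those of Dormand & Prince (1980), that are commonly used in practice. The function name and tolerances are hypothetical.

```python
import math
from typing import Callable


def adaptive_heun(
    f: Callable[[float, float], float],
    x0: float,
    t0: float,
    t1: float,
    rtol: float = 1e-6,
    atol: float = 1e-9,
    h0: float = 1e-2,
) -> tuple[float, int]:
    """Integrate scalar dx/dt = f(t, x) from t0 to t1 with an embedded Heun-Euler pair.

    Returns (x(t1), nfe), where nfe counts calls to f, including rejected steps.
    """
    t, x, h, nfe = t0, x0, h0, 0
    while t < t1:
        last = (t1 - t) <= h
        if last:
            h = t1 - t  # clamp the final step so we land exactly on t1
        k1 = f(t, x)  # recomputed after a rejected step, for simplicity
        k2 = f(t + h, x + h * k1)
        nfe += 2
        x_heun = x + 0.5 * h * (k1 + k2)  # 2nd-order solution, kept when accepted
        x_euler = x + h * k1  # 1st-order solution, only used for the error estimate
        err = abs(x_heun - x_euler)
        tol = atol + rtol * max(abs(x), abs(x_heun))
        if err <= tol:  # accept the step
            t = t1 if last else t + h
            x = x_heun
        # Step-size controller: safety factor 0.9, exponent 1/2 for a 1st-order
        # error estimate, growth/shrink clipped to [0.2, 5].
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    return x, nfe
```

For example, `adaptive_heun(lambda t, x: math.cos(20 * t), 0.0, 0.0, 1.0)` reports a larger NFE than the same call with `math.cos(t)`, because the faster-varying field forces smaller steps.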

By combining our training techniques with our dequantization, we exceed the likelihood of previous ODEs, with a particularly large margin on CIFAR-10. Figure 1 compares our pretraining phase with VDM (Kingma et al., 2021). It indicates that our techniques converge roughly 2x to 3x faster. Following VDM, we further strengthen the likelihood results with data augmentation and a larger network. We observe that augmented training data may cause fluctuations in the training and test losses. Selecting the model with the best test performance, we obtain an SDE likelihood of 2.46 BPD at only around 2M iterations, compared to 2.49 for VDM at 10M iterations.

Unlike Lipman et al. (2022), we do not observe an advantage of SP over VP, such as lower FID or NFE. We suspect this may result from maximum likelihood training, which puts more emphasis on the high log-SNR region. More theoretical comparisons with Lipman et al. (2022) are given in Appendix E.2.

Randomly generated samples from our models are provided in Appendix J. Our network architecture and techniques target likelihood, so our FID is worse than the state of the art. It could be improved by designing a time weighting that emphasizes training at small log-SNR levels (Kingma et al., 2021), or by using high-quality sampling algorithms such as the PC sampler (Song et al., 2021c). Moreover, data augmentation and a larger network notably improve the FID to 3.76 on CIFAR-10, while also achieving state-of-the-art likelihood.

6.2 Ablations

Figure 2: Training curve from scratch for ablation. We compute the loss on the training set by the SDE likelihood bound in Kingma et al. (2021).
Panels: (a) $\gamma(t)$; (b) variance at different log-SNR levels.
Figure 3: Visualization of importance sampling. (a) The inverse cumulative distribution function $\gamma(t)$ of the proposal distribution $p(\gamma)$, which maps uniform $t$ to importance-sampled $\gamma$. (b) The variance of the Monte Carlo estimator $\mathrm{Var}\left[\gamma^{\prime}(t)\mathcal{L}_{\theta}(\bm{x}_0,\bm{\epsilon},\gamma(t))\right]$ at different noise levels, estimated using 32 data samples $\bm{x}_0$ and 100 noise samples $\bm{\epsilon}$. The peak variance is achieved around $\gamma=-11.2$.

Due to the high time cost of pretraining, we conduct ablation studies only on CIFAR-10 under the VP schedule. First, we test our pretraining techniques when training from scratch. We plot the training curves with the noise predictor (Kingma et al., 2021) and the velocity predictor, and then additionally apply our IS strategies (Figure 2). Velocity parameterization and IS both accelerate training, while designed IS performs slightly worse than adaptive IS. Considering the extra time cost of learning the IS network, we conclude that designed IS is the better choice for large-scale pretraining. We then visualize the different IS strategies by plotting the mapping from uniform $t$ to importance-sampled $\gamma$, as well as the variance at different noise levels on the pretrained model (Figure 3). This shows that IS reduces the variance by sampling more in high log-SNR regions.
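The closed form of the paper's designed IS is given in Appendix C and is not reproduced here. As a stand-in, the sketch below illustrates the mechanism in Figure 3 with a hypothetical integrand: a Gaussian bump at high log-SNR (low $\gamma$) plays the role of the per-$\gamma$ loss. It also uses a hypothetical truncated Cauchy proposal. Uniform $t$ is mapped through the proposal's inverse CDF $\gamma(t)$ and weighted by $\gamma'(t)=1/p(\gamma(t))$. This keeps the estimate of $\int f(\gamma)\,\mathrm{d}\gamma$ unbiased while concentrating samples where the integrand is large.

```python
import math
import random


def bump(gamma: float, center: float = -10.0) -> float:
    """Hypothetical per-gamma loss magnitude: a Gaussian bump at high log-SNR."""
    return math.exp(-0.5 * (gamma - center) ** 2)


def uniform_sample(f, lo: float, hi: float, t: float) -> float:
    """Uniform proposal: gamma(t) = lo + t (hi - lo), so gamma'(t) = hi - lo."""
    gamma = lo + t * (hi - lo)
    return (hi - lo) * f(gamma)


def truncated_cauchy_sample(f, lo, hi, center, scale, t):
    """Hypothetical proposal p(gamma): Cauchy(center, scale) truncated to [lo, hi].

    gamma(t) is its inverse CDF, and gamma'(t) = 1 / p(gamma(t)) is the importance weight.
    """
    a = math.atan((lo - center) / scale)
    b = math.atan((hi - center) / scale)
    theta = a + t * (b - a)
    gamma = center + scale * math.tan(theta)
    weight = scale * (b - a) / math.cos(theta) ** 2  # d gamma / d t
    return weight * f(gamma)


def mean_and_variance(xs):
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


rng = random.Random(0)
ts = [rng.random() for _ in range(100_000)]
lo, hi = -15.0, 15.0
m_u, v_u = mean_and_variance([uniform_sample(bump, lo, hi, t) for t in ts])
m_is, v_is = mean_and_variance(
    [truncated_cauchy_sample(bump, lo, hi, -10.0, 1.0, t) for t in ts]
)
```

Both estimators target $\int f\,\mathrm{d}\gamma\approx\sqrt{2\pi}$. Analytically, the per-sample variance is about 47 for the uniform proposal and about 1.4 for the truncated Cauchy proposal. This mirrors how IS concentrates samples where the variance peaks.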

Table 2: Ablation study when converged. We report negative log-likelihood (NLL) in bits/dim (BPD), sample quality (FID scores), and number of function evaluations (NFE) after our pretraining and finetuning phase. We evaluate NLL by uniform (U) and truncated-normal (TN) dequantization without importance weight. We retrain VDM and evaluate its ODE form.
Model | NLL (U) | NLL (TN) | FID | NFE
VDM (Kingma et al., 2021) | 2.78 | 2.64 | 8.65 | 213
Pretrain (ours) | 2.75 | 2.61 | 10.66 | 248
+ Finetune (ours) | 2.74 | 2.60 | 10.74 | 126

Next, we test our pretraining, finetuning and evaluation on the converged model (Table 2). As stated before, our pretraining has faster loss descent and converges to a higher likelihood than VDM. On top of it, our finetuning slightly improves the ODE likelihood and smooths the flow, leading to far fewer NFE when sampling (126 vs. 248). Our truncated-normal dequantization is also a key factor for accurate likelihood evaluation. It improves over uniform dequantization by about 0.14 BPD for every model in Table 2.

Consistent with Song et al. (2021b), our improvements in likelihood lead to slightly worse FIDs, although we argue that the degradation in visual quality is small. We provide additional samples in Appendix J for comparison.

7 Conclusion

We propose improved techniques for simulation-free maximum likelihood training and likelihood evaluation of diffusion ODEs. Our training stage involves improved pretraining and additional finetuning, which result in fast convergence, high likelihood and smooth trajectories. We improve likelihood evaluation with a novel truncated-normal dequantization, which is training-free and tailored to diffusion ODEs. Empirically, we achieve state-of-the-art likelihood on image datasets without variational dequantization or data augmentation, and we substantially improve over previous ODEs on CIFAR-10. Due to resource limitations, we did not explore tuning of hyperparameters and network architectures, which we leave for future work.

Acknowledgements

This work was supported by the National Key Research and Development Program of China (2020AAA0106302); NSF of China Projects (Nos. 62061136001, 61620106010, 62076145, U19B2034, U1811461, U19A2081, 6197222, 62106120, 62076145); a grant from Tsinghua Institute for Guo Qiang; the High Performance Computing Center, Tsinghua University. J.Z. was also supported by the New Cornerstone Science Foundation through the XPLORER PRIZE. The large-scale training was supported by Shengshu Technology.

References

  • Albergo & Vanden-Eijnden (2022) Albergo, M. S. and Vanden-Eijnden, E. Building normalizing flows with stochastic interpolants. In The Eleventh International Conference on Learning Representations, 2022.
  • Anderson (1982) Anderson, B. D. Reverse-time diffusion equation models. Stochastic Processes and their Applications, 12(3):313–326, 1982.
  • Bradbury et al. (2018) Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., et al. JAX: composable transformations of Python+NumPy programs. Version 0.2, 5:14–24, 2018.
  • Burda et al. (2015) Burda, Y., Grosse, R., and Salakhutdinov, R. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.
  • Chen et al. (2021) Chen, N., Zhang, Y., Zen, H., Weiss, R. J., Norouzi, M., and Chan, W. WaveGrad: Estimating gradients for waveform generation. In International Conference on Learning Representations, 2021.
  • Chen et al. (2018a) Chen, R. T., Rubanova, Y., Bettencourt, J., and Duvenaud, D. Neural ordinary differential equations. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp.  6572–6583, 2018a.
  • Chen et al. (2018b) Chen, X., Mishra, N., Rohaninejad, M., and Abbeel, P. PixelSNAIL: An improved autoregressive generative model. In International Conference on Machine Learning, pp. 864–872. PMLR, 2018b.
  • Chen et al. (2018c) Chen, Z., Yeo, C. K., Lee, B. S., and Lau, C. T. Autoencoder-based network anomaly detection. In 2018 Wireless telecommunications symposium (WTS), pp. 1–5. IEEE, 2018c.
  • Choi et al. (2022) Choi, K., Meng, C., Song, Y., and Ermon, S. Density ratio estimation via infinitesimal classification. In International Conference on Artificial Intelligence and Statistics, pp.  2552–2573. PMLR, 2022.
  • Chung et al. (2022) Chung, H., Kim, J., Mccann, M. T., Klasky, M. L., and Ye, J. C. Diffusion posterior sampling for general noisy inverse problems. In The Eleventh International Conference on Learning Representations, 2022.
  • Deng et al. (2009) Deng, J., Dong, W., Socher, R., Li, L., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp.  248–255. IEEE, 2009.
  • Dhariwal & Nichol (2021) Dhariwal, P. and Nichol, A. Q. Diffusion models beat GANs on image synthesis. In Advances in Neural Information Processing Systems, volume 34, pp.  8780–8794, 2021.
  • Dias et al. (2020) Dias, M. L., Mattos, C. L. C., da Silva, T. L., de Macedo, J. A. F., and Silva, W. C. Anomaly detection in trajectory data with normalizing flows. In 2020 International Joint Conference on Neural Networks (IJCNN), pp.  1–8. IEEE, 2020.
  • Dinh et al. (2017) Dinh, L., Sohl-Dickstein, J., and Bengio, S. Density estimation using Real NVP. In International Conference on Learning Representations, 2017.
  • Dormand & Prince (1980) Dormand, J. R. and Prince, P. J. A family of embedded Runge-Kutta formulae. Journal of computational and applied mathematics, 6(1):19–26, 1980.
  • Finlay et al. (2020) Finlay, C., Jacobsen, J.-H., Nurbekyan, L., and Oberman, A. How to train your neural ODE: the world of Jacobian and kinetic regularization. In International Conference on Machine Learning, pp. 3154–3164. PMLR, 2020.
  • Grathwohl et al. (2019) Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., and Duvenaud, D. FFJORD: Free-form continuous dynamics for scalable reversible generative models. In International Conference on Learning Representations, 2019.
  • Helminger et al. (2020) Helminger, L., Djelouah, A., Gross, M., and Schroers, C. Lossy image compression with normalizing flows. arXiv preprint arXiv:2008.10486, 2020.
  • Ho et al. (2019) Ho, J., Chen, X., Srinivas, A., Duan, Y., and Abbeel, P. Flow++: Improving flow-based generative models with variational dequantization and architecture design. In International Conference on Machine Learning, pp. 2722–2730. PMLR, 2019.
  • Ho et al. (2020) Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models. In Advances in Neural Information Processing Systems, volume 33, pp.  6840–6851, 2020.
  • Ho et al. (2022) Ho, J., Chan, W., Saharia, C., Whang, J., Gao, R., Gritsenko, A., Kingma, D. P., Poole, B., Norouzi, M., Fleet, D. J., et al. Imagen video: High definition video generation with diffusion models. arXiv preprint arXiv:2210.02303, 2022.
  • Ho et al. (2021) Ho, Y.-H., Chan, C.-C., Peng, W.-H., Hang, H.-M., and Domański, M. ANFIC: Image compression using augmented normalizing flows. IEEE Open Journal of Circuits and Systems, 2:613–626, 2021.
  • Huang et al. (2021) Huang, C.-W., Lim, J. H., and Courville, A. A variational perspective on diffusion-based generative models and score matching. In Advances in Neural Information Processing Systems, 2021.
  • Hutchinson (1990) Hutchinson, M. F. A stochastic estimator of the trace of the influence matrix for Laplacian smoothing splines. Communications in Statistics-Simulation and Computation, 19(2):433–450, 1990.
  • Karras et al. (2022) Karras, T., Aittala, M., Aila, T., and Laine, S. Elucidating the design space of diffusion-based generative models. In Advances in Neural Information Processing Systems, 2022.
  • Kawar et al. (2022) Kawar, B., Elad, M., Ermon, S., and Song, J. Denoising diffusion restoration models. In Advances in Neural Information Processing Systems, 2022.
  • Kim et al. (2022) Kim, D., Shin, S., Song, K., Kang, W., and Moon, I.-C. Soft truncation: A universal training technique of score-based diffusion model for high precision score estimation. In International Conference on Machine Learning, pp. 11201–11228. PMLR, 2022.
  • Kingma & Ba (2014) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kingma & Dhariwal (2018) Kingma, D. P. and Dhariwal, P. Glow: generative flow with invertible 1×1 convolutions. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp.  10236–10245, 2018.
  • Kingma & Welling (2014) Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.
  • Kingma et al. (2021) Kingma, D. P., Salimans, T., Poole, B., and Ho, J. Variational diffusion models. In Advances in Neural Information Processing Systems, 2021.
  • Krizhevsky et al. (2009) Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. 2009.
  • Lipman et al. (2022) Lipman, Y., Chen, R. T., Ben-Hamu, H., Nickel, M., and Le, M. Flow matching for generative modeling. In The Eleventh International Conference on Learning Representations, 2022.
  • Liu et al. (2022a) Liu, J., Li, C., Ren, Y., Chen, F., and Zhao, Z. DiffSinger: Singing voice synthesis via shallow diffusion mechanism. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 36, pp.  11020–11028, 2022a.
  • Liu et al. (2022b) Liu, X., Gong, C., et al. Flow straight and fast: Learning to generate and transfer data with rectified flow. In The Eleventh International Conference on Learning Representations, 2022b.
  • Loshchilov & Hutter (2019) Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
  • Lu et al. (2022a) Lu, C., Zheng, K., Bao, F., Chen, J., Li, C., and Zhu, J. Maximum likelihood training for score-based diffusion ODEs by high order denoising score matching. In International Conference on Machine Learning, pp. 14429–14460. PMLR, 2022a.
  • Lu et al. (2022b) Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., and Zhu, J. DPM-Solver: A fast ODE solver for diffusion probabilistic model sampling in around 10 steps. In Advances in Neural Information Processing Systems, 2022b.
  • Meng et al. (2022) Meng, C., Song, Y., Song, J., Wu, J., Zhu, J.-Y., and Ermon, S. SDEdit: Image synthesis and editing with stochastic differential equations. In International Conference on Learning Representations, 2022.
  • Nichol & Dhariwal (2021) Nichol, A. Q. and Dhariwal, P. Improved denoising diffusion probabilistic models. In International Conference on Machine Learning, pp. 8162–8171. PMLR, 2021.
  • Nichol et al. (2022) Nichol, A. Q., Dhariwal, P., Ramesh, A., Shyam, P., Mishkin, P., Mcgrew, B., Sutskever, I., and Chen, M. Glide: Towards photorealistic image generation and editing with text-guided diffusion models. In International Conference on Machine Learning, pp. 16784–16804. PMLR, 2022.
  • Oord et al. (2016) Oord, A. v. d., Kalchbrenner, N., Vinyals, O., Espeholt, L., Graves, A., and Kavukcuoglu, K. Conditional image generation with PixelCNN decoders. In Proceedings of the 30th International Conference on Neural Information Processing Systems, pp.  4797–4805, 2016.
  • Ramesh et al. (2022) Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., and Chen, M. Hierarchical text-conditional image generation with CLIP latents. arXiv preprint arXiv:2204.06125, 2022.
  • Rombach et al. (2022) Rombach, R., Blattmann, A., Lorenz, D., Esser, P., and Ommer, B. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  10684–10695, 2022.
  • Salimans & Ho (2022) Salimans, T. and Ho, J. Progressive distillation for fast sampling of diffusion models. In International Conference on Learning Representations, 2022.
  • Salimans et al. (2017) Salimans, T., Karpathy, A., Chen, X., and Kingma, D. P. PixelCNN++: Improving the PixelCNN with discretized logistic mixture likelihood and other modifications. In International Conference on Learning Representations, 2017.
  • Serrà et al. (2020) Serrà, J., Álvarez, D., Gómez, V., Slizovskaia, O., Núñez, J. F., and Luque, J. Input complexity and out-of-distribution detection with likelihood-based generative models. In International Conference on Learning Representations, 2020.
  • Sohl-Dickstein et al. (2015) Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., and Ganguli, S. Deep unsupervised learning using nonequilibrium thermodynamics. In International Conference on Machine Learning, pp. 2256–2265. PMLR, 2015.
  • Song et al. (2021a) Song, J., Meng, C., and Ermon, S. Denoising diffusion implicit models. In International Conference on Learning Representations, 2021a.
  • Song & Ermon (2019) Song, Y. and Ermon, S. Generative modeling by estimating gradients of the data distribution. In Advances in Neural Information Processing Systems, volume 32, pp.  11895–11907, 2019.
  • Song et al. (2020) Song, Y., Garg, S., Shi, J., and Ermon, S. Sliced score matching: A scalable approach to density and score estimation. In Uncertainty in Artificial Intelligence, pp.  574–584. PMLR, 2020.
  • Song et al. (2021b) Song, Y., Durkan, C., Murray, I., and Ermon, S. Maximum likelihood training of score-based diffusion models. In Advances in Neural Information Processing Systems, volume 34, pp.  1415–1428, 2021b.
  • Song et al. (2021c) Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., and Poole, B. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021c.
  • Uria et al. (2013) Uria, B., Murray, I., and Larochelle, H. RNADE: The real-valued neural autoregressive density-estimator. Advances in Neural Information Processing Systems, 26, 2013.
  • Vahdat & Kautz (2020) Vahdat, A. and Kautz, J. NVAE: a deep hierarchical variational autoencoder. In Proceedings of the 34th International Conference on Neural Information Processing Systems, pp.  19667–19679, 2020.
  • Vincent (2011) Vincent, P. A connection between score matching and denoising autoencoders. Neural computation, 23(7):1661–1674, 2011.
  • Xiao et al. (2020) Xiao, Z., Yan, Q., and Amit, Y. Likelihood regret: an out-of-distribution detection score for variational auto-encoder. In Proceedings of the 34th International Conference on Neural Information Processing Systems, pp.  20685–20696, 2020.
  • Xu et al. (2022) Xu, Y., Liu, Z., Tegmark, M., and Jaakkola, T. S. Poisson flow generative models. In Advances in Neural Information Processing Systems, 2022.
  • Xu et al. (2023) Xu, Y., Liu, Z., Tian, Y., Tong, S., Tegmark, M., and Jaakkola, T. PFGM++: Unlocking the potential of physics-inspired generative models. arXiv preprint arXiv:2302.04265, 2023.
  • Yang & Mandt (2022) Yang, R. and Mandt, S. Lossy image compression with conditional diffusion models. arXiv preprint arXiv:2209.06950, 2022.
  • Zhao et al. (2022) Zhao, M., Bao, F., Li, C., and Zhu, J. EGSDE: Unpaired image-to-image translation via energy-guided stochastic differential equations. In Advances in Neural Information Processing Systems, 2022.

Appendix A Different perspectives of diffusion ODEs for bridging the gap between discrete and continuous data

Suppose the discrete data $\bm{X}_0$ to be modelled are 8-bit integers $\{0,1,\dots,255\}$. Following the common transform in diffusion models, we first normalize them to the range $[-1,1]$ by the mapping $\bm{x}_0=\frac{\bm{X}_0+\frac{1}{2}-128}{128}$. In the following discussion, we consider the model distribution $P_0(\bm{x}_0)$ on the transformed discrete data $\bm{x}_0$, which equals $P_0(\bm{X}_0)$ because the scaling does not alter the discrete probability.
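As a quick check of the mapping, the sketch below sends each 8-bit value to the center of a bin of width $\frac{1}{128}$. The bins of half-width $\frac{1}{256}$ then tile $[-1,1]$ exactly. The helper and constant names are hypothetical.

```python
BIN_HALF_WIDTH = 1 / 256  # consecutive bin centers are 1/128 apart


def to_model_space(X: int) -> float:
    """Map an 8-bit value X in {0, ..., 255} to the center of its bin in [-1, 1]."""
    if not 0 <= X <= 255:
        raise ValueError("expected an 8-bit integer")
    return (X + 0.5 - 128) / 128
```

The outermost bin edges are $-\frac{127.5}{128}-\frac{1}{256}=-1$ and $\frac{127.5}{128}+\frac{1}{256}=1$.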

A.1 Dequantization perspective

The discrete data $\bm{x}_0$ has a uniform gap of $\frac{1}{128}$ between two consecutive values in each dimension. We can define the discrete model distribution as

$$P_0(\bm{x}_0)=\int_{\bm{u}\in[-\frac{1}{256},\frac{1}{256}]^{d}}p_{\epsilon}(\bm{x}_0+\bm{u})\,\mathrm{d}\bm{u} \tag{20}$$
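To make Eq. (20) concrete, the sketch below evaluates it in one dimension. A Gaussian density stands in for $p_\epsilon$; this is an assumption purely for illustration, since in the paper $p_\epsilon$ is the diffusion ODE density and has no closed-form CDF. The integral over each bin is then a difference of CDF values, and the 256 bin probabilities telescope to the mass that the density places on $[-1,1]$. The helper name is hypothetical.

```python
from statistics import NormalDist


def bin_probability(density: NormalDist, X: int) -> float:
    """1-D Eq. (20): integral of the density over [x0 - 1/256, x0 + 1/256]."""
    x0 = (X + 0.5 - 128) / 128  # bin center in [-1, 1]
    half_width = 1 / 256
    return density.cdf(x0 + half_width) - density.cdf(x0 - half_width)


# Hypothetical stand-in density for p_eps.
p_eps = NormalDist(mu=0.1, sigma=0.3)
probs = [bin_probability(p_eps, X) for X in range(256)]
```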

where $p_{\epsilon}$ is the density of the diffusion ODE at time $\epsilon$. We can then introduce a dequantization distribution $q(\bm{u}|\bm{x}_0)$ with support on $[-\frac{1}{256},\frac{1}{256})^{d}$. Treating $q$ as an approximate posterior, we obtain the following variational bound (Ho et al., 2019):

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{q(\bm{u}|\bm{x}_0)}\left[\log p_{\epsilon}(\bm{x}_0+\bm{u})-\log q(\bm{u}|\bm{x}_0)\right] \tag{21}$$

The ODE term $\log p_{\epsilon}(\bm{x}_0+\bm{u})$ can be evaluated exactly by solving another ODE, given by the "Instantaneous Change of Variables" (Chen et al., 2018a). The posterior term $\log q(\bm{u}|\bm{x}_0)$ has a closed form for a predefined posterior family. We provide the details for uniform dequantization and for our proposed truncated-normal dequantization.
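In a toy setting, "solving another ODE" amounts to integrating the state together with the trace of the Jacobian of the vector field. The sketch below assumes a hypothetical linear drift $f(\bm{x})=\bm{a}\odot\bm{x}$, not the learned diffusion ODE drift, whose Jacobian trace is exactly $\sum_i a_i$. It also assumes data at time $0$ and a standard normal prior at time $T$, so the answer is known in closed form. For a neural network drift, the trace is usually estimated stochastically, e.g. with Hutchinson's estimator (Hutchinson, 1990), which is not shown here. All names are hypothetical.

```python
import math


def rk4_step(g, t, y, h):
    """One classic fourth-order Runge-Kutta step for y' = g(t, y) on Python lists."""
    k1 = g(t, y)
    k2 = g(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
    k3 = g(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
    k4 = g(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + e) for yi, a, b, c, e in zip(y, k1, k2, k3, k4)]


def std_normal_logpdf(x):
    """log N(x; 0, I) for a list x."""
    return sum(-0.5 * xi * xi - 0.5 * math.log(2 * math.pi) for xi in x)


def log_likelihood_via_ode(x0, a, T=1.0, steps=200):
    """log p_0(x0) = log p_T(x_T) + int_0^T tr(df/dx) dt, with prior p_T = N(0, I).

    Hypothetical toy drift f(t, x) = a * x (elementwise), so tr(df/dx) = sum(a).
    The state is augmented with the running integral of the trace.
    """
    def augmented(t, y):
        x = y[:-1]
        dx = [ai * xi for ai, xi in zip(a, x)]
        return dx + [sum(a)]  # last entry: d/dt of the accumulated trace

    y = list(x0) + [0.0]
    h = T / steps
    for n in range(steps):
        y = rk4_step(augmented, n * h, y, h)
    x_T, trace_integral = y[:-1], y[-1]
    return std_normal_logpdf(x_T) + trace_integral
```

For this drift, $\bm{x}_T=\bm{x}_0\odot e^{\bm{a}T}$, so the result should match $\sum_i \log\mathcal{N}(x_{0,i};0,e^{-2a_iT})$.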

Uniform dequantization

We simply use the uniform posterior $q(\bm{u}|\bm{x}_0)=\mathcal{U}(-\frac{1}{256},\frac{1}{256})$. In this case, $\log q(\bm{u}|\bm{x}_0)=d\log 128$ is a constant, and the bound becomes

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\bm{u}\sim\mathcal{U}(-\frac{1}{256},\frac{1}{256})}\left[\log p_{\epsilon}(\bm{x}_0+\bm{u})\right]-d\log 128 \tag{22}$$
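Likelihoods are reported in bits/dim, so the bound in Eq. (22) is converted by dividing by $d\ln 2$. The sketch below makes the constant explicit. For data scaled to $[-1,1]$, the $-d\log 128$ term contributes exactly $\log_2 128=7$ bits/dim. Because Eq. (22) lower-bounds $\log P_0$, the result is an upper bound on the true BPD. The helper name is hypothetical, and its input is assumed to be a Monte Carlo estimate of the expectation in nats.

```python
import math


def bpd_uniform_dequant(log_p_dequantized: float, d: int) -> float:
    """Bits/dim implied by Eq. (22) (an upper bound on the true BPD).

    log_p_dequantized: estimate of E_u[log p_eps(x0 + u)] in nats, for data in [-1, 1]
    (bin width 1/128), where d is the number of dimensions.
    """
    log_P_lower = log_p_dequantized - d * math.log(128)  # lower bound on log P0(x0)
    return -log_P_lower / (d * math.log(2))
```

As a sanity check, a uniform density on $[-1,1]^d$ (so that $\log p_\epsilon=-d\ln 2$) gives 8 bits/dim, the cost of coding 256 equally likely values.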

Similar to Burda et al. (2015), we can also sample multiple $\bm{u}$ to obtain a tighter bound, known as the importance-weighted estimator:

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\bm{u}^{(1)},\dots,\bm{u}^{(K)}\sim\mathcal{U}(-\frac{1}{256},\frac{1}{256})}\left[\log\left(\frac{1}{K}\sum_{i=1}^{K}p_{\epsilon}(\bm{x}_0+\bm{u}^{(i)})\right)\right]-d\log 128 \tag{23}$$
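Eq. (23) takes the log of an average of $K$ likelihoods. For images, each log-likelihood is in the thousands of nats, so exponentiating it directly underflows in floating point. A standard way to compute the average is the log-mean-exp trick sketched below. The helper name is hypothetical. Its inputs are assumed to be the $K$ values $\log p_\epsilon(\bm{x}_0+\bm{u}^{(i)})$ returned by the ODE likelihood computation, before the $-d\log 128$ term is applied.

```python
import math


def importance_weighted_log_p(log_p_samples):
    """log((1/K) * sum_i exp(l_i)), computed stably by factoring out max_i l_i."""
    m = max(log_p_samples)
    k = len(log_p_samples)
    return m + math.log(sum(math.exp(l - m) for l in log_p_samples)) - math.log(k)
```

By Jensen's inequality, the result is never below the plain average of the $\log p$ values. This is why the importance-weighted bound is tighter than Eq. (22).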

However, this dequantization causes a training-evaluation gap. During training, we fit $p_{\epsilon}$ to the distribution of $\bm{x}_{\epsilon}=\alpha_{\epsilon}\bm{x}_0+\sigma_{\epsilon}\bm{\epsilon}$, $\bm{\epsilon}\sim\mathcal{N}(\mathbf{0},\bm{I})$. During evaluation, we test $p_{\epsilon}$ on the uniformly dequantized $\bm{x}_0+\bm{u}$, $\bm{u}\sim\mathcal{U}(-\frac{1}{256},\frac{1}{256})$. This gap degrades the likelihood performance, as we show later.

Truncated-normal dequantization

To bridge the training-evaluation gap, we test $p_{\epsilon}$ on $\hat{\bm{x}}_{\epsilon}=\alpha_{\epsilon}\bm{x}_0+\sigma_{\epsilon}\hat{\bm{\epsilon}}$, where $\hat{\bm{\epsilon}}$ follows a truncated-normal distribution that keeps each dimension of $\bm{u}$ within $[-\frac{1}{256},\frac{1}{256}]$. Specifically, denoting $\tau\coloneqq\frac{\alpha_{\epsilon}}{256\sigma_{\epsilon}}$, we define the truncated-normal distribution as

$$\hat{\bm{\epsilon}}\sim\mathcal{T}\mathcal{N}\left(\hat{\bm{\epsilon}}\,\middle|\,\bm{0},\bm{I},-\tau,\tau\right) \tag{24}$$

Let

$$\bm{u}\coloneqq\frac{\sigma_\epsilon}{\alpha_\epsilon}\hat{\bm{\epsilon}}\in\left[-\frac{1}{256},\frac{1}{256}\right] \tag{25}$$

By the change-of-variables formula for probability densities, we have

$$\log p_\epsilon(\bm{x}_0+\bm{u})=\log p_\epsilon\left(\bm{x}_0+\frac{\sigma_\epsilon}{\alpha_\epsilon}\hat{\bm{\epsilon}}\right)=\log p_\epsilon(\hat{\bm{x}}_\epsilon)+d\log\alpha_\epsilon \tag{26}$$
$$\log q(\bm{u}|\bm{x}_0)=\log q\left(\frac{\sigma_\epsilon}{\alpha_\epsilon}\hat{\bm{\epsilon}}\right)=\log q(\hat{\bm{\epsilon}})+d\log\frac{\alpha_\epsilon}{\sigma_\epsilon} \tag{27}$$

where $q(\hat{\bm{\epsilon}})$ is the probability density function of the truncated-normal distribution, supported on $[-\tau,\tau]^d$:

$$q(\hat{\bm{\epsilon}})=\frac{1}{(2\pi Z^2)^{\frac{d}{2}}}\exp\left(-\frac{1}{2}\|\hat{\bm{\epsilon}}\|_2^2\right),\quad Z\coloneqq\Phi(\tau)-\Phi(-\tau)=\operatorname{erf}\left(\frac{\tau}{\sqrt{2}}\right) \tag{28}$$

Here $\Phi(\cdot)$ is the cumulative distribution function of the standard normal distribution, and $\operatorname{erf}(\cdot)$ is the error function. Combining the equations above, the bound reduces to

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{q(\hat{\bm{\epsilon}})}\left[\log p_\epsilon(\hat{\bm{x}}_\epsilon)-\log q(\hat{\bm{\epsilon}})\right]+d\log\sigma_\epsilon \tag{29}$$
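The sampling and density steps in Eqns. (24), (25), and (28) can be sketched in standard-library Python as below. This is a minimal illustration under our own assumptions, not the authors' implementation: each dimension is truncated independently, $\hat{\bm{\epsilon}}$ is drawn by inverse-CDF sampling, and all function names are hypothetical. The ODE term $\log p_\epsilon(\hat{\bm{x}}_\epsilon)$ in Eqn. (29) needs an exact-likelihood ODE solver and is not shown.

```python
import math
import random
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # standard normal N(0, 1)


def sample_truncated_normal(d, tau, rng=random):
    """Draw eps_hat ~ TN(0, I_d, -tau, tau) (Eqn. 24) by inverse-CDF sampling."""
    lo, hi = _STD_NORMAL.cdf(-tau), _STD_NORMAL.cdf(tau)
    samples = []
    for _ in range(d):
        # Uniform draw in [Phi(-tau), Phi(tau)), mapped back through Phi^{-1}.
        x = _STD_NORMAL.inv_cdf(lo + (hi - lo) * rng.random())
        # Clamp to guard against floating-point round-off at the boundary.
        samples.append(min(max(x, -tau), tau))
    return samples


def dequantization_noise(eps_hat, alpha_eps, sigma_eps):
    """u = (sigma_eps / alpha_eps) * eps_hat (Eqn. 25); each entry lies in [-1/256, 1/256]."""
    return [sigma_eps / alpha_eps * e for e in eps_hat]


def log_q_truncated_normal(eps_hat, tau):
    """log q(eps_hat) from Eqn. (28); -inf outside the support [-tau, tau]^d."""
    if any(abs(e) > tau for e in eps_hat):
        return -math.inf
    d = len(eps_hat)
    Z = math.erf(tau / math.sqrt(2.0))  # Z = Phi(tau) - Phi(-tau)
    return -0.5 * d * math.log(2.0 * math.pi * Z * Z) - 0.5 * sum(e * e for e in eps_hat)
```

Inverse-CDF sampling is used here because it needs no rejection loop; for $\tau\approx3$ a simple rejection sampler on $\mathcal{N}(0,1)$ would also work, since it rejects only about 0.25% of draws per dimension.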

Furthermore, the entropy of the truncated-normal distribution has a closed form:

$$-\mathbb{E}_{q(\hat{\bm{\epsilon}})}[\log q(\hat{\bm{\epsilon}})]=\mathcal{H}(q(\hat{\bm{\epsilon}}))=d\log\sqrt{2\pi e}+d\log Z-d\frac{\tau}{\sqrt{2\pi}Z}\exp\left(-\frac{1}{2}\tau^2\right) \tag{30}$$

and we finally obtain the exact form of the bound:

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\hat{\bm{\epsilon}}\sim\mathcal{T}\mathcal{N}(\bm{0},\bm{I},-\tau,\tau)}\left[\log p_\epsilon(\hat{\bm{x}}_\epsilon)\right]+\frac{d}{2}\left(1+\log(2\pi\sigma_\epsilon^2)\right)+d\log Z-d\frac{\tau}{\sqrt{2\pi}Z}\exp\left(-\frac{1}{2}\tau^2\right) \tag{31}$$

where the ODE log-likelihood $\log p_\epsilon(\hat{\bm{x}}_\epsilon)$ can also be evaluated exactly. Similarly, modifying Eqn. (29) gives the corresponding importance-weighted bound:

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\hat{\bm{\epsilon}}^{(1)},\dots,\hat{\bm{\epsilon}}^{(K)}\sim\mathcal{T}\mathcal{N}(\bm{0},\bm{I},-\tau,\tau)}\left[\log\left(\frac{1}{K}\sum_{i=1}^{K}\frac{p_\epsilon(\hat{\bm{x}}_\epsilon^{(i)})}{q(\hat{\bm{\epsilon}}^{(i)})}\right)\right]+d\log\sigma_\epsilon \tag{32}$$

where $\hat{\bm{x}}_\epsilon^{(i)}\coloneqq\alpha_\epsilon\bm{x}_0+\sigma_\epsilon\hat{\bm{\epsilon}}^{(i)}$, and $q(\hat{\bm{\epsilon}})$ is given in Eqn. (28).
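Eqns. (30) and (31) can be checked numerically with the standard-library sketch below (function names are hypothetical). It computes the closed-form entropy, compares it with a one-dimensional Monte Carlo estimate of $-\mathbb{E}[\log q(\hat{\bm{\epsilon}})]$, and assembles the right-hand side of Eqn. (31), assuming the Monte Carlo average of the exact ODE log-likelihood is supplied from elsewhere.

```python
import math
import random
from statistics import NormalDist


def truncated_normal_entropy(d, tau):
    """Closed-form entropy (nats) of N(0, I_d) truncated to [-tau, tau]^d, Eqn. (30)."""
    Z = math.erf(tau / math.sqrt(2.0))
    per_dim = (0.5 * math.log(2.0 * math.pi * math.e)
               + math.log(Z)
               - tau / (math.sqrt(2.0 * math.pi) * Z) * math.exp(-0.5 * tau ** 2))
    return d * per_dim


def truncated_normal_bound(mean_log_p, d, sigma_eps, tau):
    """Right-hand side of Eqn. (31) in nats.

    mean_log_p is a Monte Carlo estimate of E[log p_eps(x_hat_eps)], assumed to come
    from an exact ODE likelihood solver (not implemented here). This equals
    Eqn. (29) with the entropy of Eqn. (30) substituted in.
    """
    return mean_log_p + truncated_normal_entropy(d, tau) + d * math.log(sigma_eps)


def mc_entropy_1d(tau, n=100_000, seed=0):
    """Monte Carlo estimate of -E[log q(eps_hat)] in one dimension (sanity check)."""
    rng = random.Random(seed)
    nd = NormalDist()
    lo, hi = nd.cdf(-tau), nd.cdf(tau)
    log_norm = 0.5 * math.log(2.0 * math.pi) + math.log(hi - lo)
    total = 0.0
    for _ in range(n):
        x = nd.inv_cdf(lo + (hi - lo) * rng.random())
        total += log_norm + 0.5 * x * x  # -log q(x) for the 1-D truncated normal
    return total / n
```

As $\tau$ grows, the truncation terms vanish and the per-dimension entropy approaches the Gaussian value $\frac{1}{2}\log(2\pi e)$; for finite $\tau$ it is slightly smaller.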

In our experiments, we choose the start time $\gamma_\epsilon=-13.3$. Under this setting, $\tau\approx3$, so by the three-sigma rule the truncated-normal distribution $\mathcal{T}\mathcal{N}(\bm{0},\bm{I},-\tau,\tau)$ is almost the same as the standard normal distribution $\mathcal{N}(\bm{0},\bm{I})$. Thus, $\bm{x}_\epsilon$ used in training and $\hat{\bm{x}}_\epsilon$ used in testing are almost identically distributed, and the training-evaluation gap is negligible.

A.2 Variational perspective

From the variational perspective, we can view the transition from the discrete $\bm{x}_0$ to the continuous $\bm{x}_\epsilon$ as a variational autoencoder, where the prior $p_\epsilon(\bm{x}_\epsilon)$ is modeled by the diffusion ODE, and the approximate posterior $q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)$ is the analytical Gaussian transition kernel of the forward diffusion process at the start time. This gives the variational bound:

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)}\left[\log p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)+\log p_\epsilon(\bm{x}_\epsilon)-\log q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)\right] \tag{33}$$

where $q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)=\mathcal{N}(\bm{x}_\epsilon|\alpha_\epsilon\bm{x}_0,\sigma_\epsilon^2\bm{I})$. We want the reconstruction term $p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)$ to approximate the true posterior $q_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)$. Note that

$$q_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)=\frac{q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)\,q_0(\bm{x}_0)}{q_\epsilon(\bm{x}_\epsilon)} \tag{34}$$

For small enough $\epsilon$, we have $q_0(\bm{x}_0)\approx q_\epsilon(\bm{x}_\epsilon)$, so $q_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)\propto q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)=\prod_i q_{0\epsilon}(\bm{x}_{\epsilon,i}|\bm{x}_{0,i})$, where $i$ indexes the dimensions. Following Kingma et al. (2021), we therefore also choose $p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)$ to be a factorized distribution:

$$p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)=\prod_i p_{\epsilon0}(\bm{x}_{0,i}|\bm{x}_{\epsilon,i}) \tag{35}$$

where each

$$p_{\epsilon0}(\bm{x}_{0,i}|\bm{x}_{\epsilon,i})\propto q_{0\epsilon}(\bm{x}_{\epsilon,i}|\bm{x}_{0,i})\propto\exp\left(-\frac{(\bm{x}_{\epsilon,i}-\alpha_\epsilon\bm{x}_{0,i})^2}{2\sigma_\epsilon^2}\right) \tag{36}$$

Since $\bm{x}_0$ is discrete, each factor can be normalized over the 256 possible values $j\in\{0,\dots,255\}$ with a softmax, where $[\bm{x}_{0,i}]$ selects the entry for the observed value of $\bm{x}_{0,i}$:

$$\log p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)=\sum_{i=1}^{d}\log\mathrm{softmax}_{j=0}^{255}\left(-\frac{(\bm{x}_{\epsilon,i}-\alpha_\epsilon j)^2}{2\sigma_\epsilon^2}\right)[\bm{x}_{0,i}] \tag{37}$$
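Eqn. (37) is a per-dimension log-softmax over the 256 candidate values. A minimal pure-Python sketch follows, written for readability rather than speed (a practical implementation would vectorize over dimensions). It follows the indexing of Eqn. (37) literally, treating each $\bm{x}_{0,i}$ as an integer $j\in\{0,\dots,255\}$ with candidate mean $\alpha_\epsilon j$; the function name is hypothetical.

```python
import math


def log_recon_prob(x_eps, x0, alpha_eps, sigma_eps, num_levels=256):
    """log p_{eps0}(x_0 | x_eps) as in Eqn. (37).

    For each dimension i: a log-softmax over candidates j = 0..num_levels-1 with
    logits -(x_eps[i] - alpha_eps * j)^2 / (2 * sigma_eps^2), read off at j = x0[i].
    x0 holds integers in [0, num_levels), following the indexing of Eqn. (37).
    """
    total = 0.0
    for xi, ki in zip(x_eps, x0):
        logits = [-(xi - alpha_eps * j) ** 2 / (2.0 * sigma_eps ** 2)
                  for j in range(num_levels)]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
        total += logits[ki] - log_norm
    return total
```

When the candidate means are far apart relative to $\sigma_\epsilon$, the softmax is nearly one-hot and the term is close to zero, which is the regime of a small start time $\epsilon$.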

The Gaussian entropy term can also be computed exactly:

$$-\mathbb{E}_{q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)}[\log q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0)]=\mathcal{H}(q_{0\epsilon}(\bm{x}_\epsilon|\bm{x}_0))=\frac{d}{2}\left(1+\log(2\pi\sigma_\epsilon^2)\right) \tag{38}$$

and the bound reduces to

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})}\left[\log p_\epsilon(\bm{x}_\epsilon)+\log p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)\right]+\frac{d}{2}\left(1+\log(2\pi\sigma_\epsilon^2)\right) \tag{39}$$

where $\bm{x}_\epsilon=\alpha_\epsilon\bm{x}_0+\sigma_\epsilon\bm{\epsilon}$, $\log p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)$ is given in Eqn. (37), and $\log p_\epsilon(\bm{x}_\epsilon)$ is the exact ODE likelihood. Modifying Eqn. (33) likewise gives the importance-weighted bound:

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\bm{\epsilon}^{(1)},\dots,\bm{\epsilon}^{(K)}\sim\mathcal{N}(\bm{0},\bm{I})}\left[\log\left(\frac{1}{K}\sum_{i=1}^{K}\frac{p_\epsilon(\bm{x}_\epsilon^{(i)})\,p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon^{(i)})}{q_{0\epsilon}(\bm{x}_\epsilon^{(i)}|\bm{x}_0)}\right)\right],\quad\bm{x}_\epsilon^{(i)}\coloneqq\alpha_\epsilon\bm{x}_0+\sigma_\epsilon\bm{\epsilon}^{(i)} \tag{40}$$
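Both importance-weighted bounds, Eqn. (32) and Eqn. (40), have the form $\log\frac{1}{K}\sum_i w_i$ plus a constant, which is best computed in log space with a log-sum-exp to avoid overflow of the exact ODE likelihoods. A minimal sketch, assuming the per-sample log-weights have already been computed (the function names are hypothetical):

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a non-empty sequence of floats."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def iw_bound(log_weights, offset=0.0):
    """One draw of the importance-weighted bound: log((1/K) * sum_i w_i) + offset.

    Eqn. (32): log_weights[i] = log p_eps(x_hat^(i)) - log q(eps_hat^(i)),
               offset = d * log(sigma_eps).
    Eqn. (40): log_weights[i] = log p_eps(x^(i)) + log p_{eps0}(x_0 | x^(i))
                                - log q_{0eps}(x^(i) | x_0),   offset = 0.
    """
    return logsumexp(log_weights) - math.log(len(log_weights)) + offset
```

With $K=1$ this reduces to the plain bound. By Jensen's inequality the $K$-sample value is never below the average of the individual log-weights, and in expectation the bound is non-decreasing in $K$, consistent with the trend in Table 3.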

A.3 Practical connections and results

First consider the bounds without the importance-weighted estimator. The bound in Eqn. (31) for truncated-normal dequantization and the bound in Eqn. (39) from the variational perspective have similar forms. With $\gamma_\epsilon=-13.3$, we have $\tau\approx3.01869$ and $Z\approx0.9974613$, so the bound in Eqn. (31) is approximately

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\hat{\bm{\epsilon}}\sim\mathcal{T}\mathcal{N}(\bm{0},\bm{I},-\tau,\tau)}\left[\log p_\epsilon(\hat{\bm{x}}_\epsilon)\right]+\frac{d}{2}\left(1+\log(2\pi\sigma_\epsilon^2)\right)-0.01522\times d \tag{41}$$
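The constants quoted above can be reproduced in a few lines of Python. The one assumption is the link between $\gamma_\epsilon$ and the noise schedule: we take $\gamma$ to be the negative log signal-to-noise ratio, $\gamma_\epsilon=\log(\sigma_\epsilon^2/\alpha_\epsilon^2)$, so that $\alpha_\epsilon/\sigma_\epsilon=e^{-\gamma_\epsilon/2}$. This convention is not stated in this section, but it is consistent with the reported $\tau\approx3.01869$.

```python
import math

gamma_eps = -13.3  # start time used in the experiments

# Assumption: gamma is the negative log-SNR, i.e. alpha_eps / sigma_eps = exp(-gamma_eps / 2).
alpha_over_sigma = math.exp(-gamma_eps / 2.0)
tau = alpha_over_sigma / 256.0          # tau = alpha_eps / (256 * sigma_eps)
Z = math.erf(tau / math.sqrt(2.0))      # Z = Phi(tau) - Phi(-tau)

# Per-dimension constant in Eqn. (41): log Z - tau * exp(-tau^2 / 2) / (sqrt(2 pi) * Z)
const_per_dim = math.log(Z) - tau / (math.sqrt(2.0 * math.pi) * Z) * math.exp(-0.5 * tau ** 2)

print(f"tau = {tau:.5f}, Z = {Z:.7f}, constant per dim = {const_per_dim:.5f}")
print(f"normal mass removed per dim = {1.0 - Z:.5f}")
```

Under this assumption the script gives $\tau\approx3.01869$, $Z\approx0.99746$, and a per-dimension constant of about $-0.01522$. It also shows that truncation removes only about 0.25% of the normal mass per dimension, which quantifies the three-sigma argument used earlier.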

Next, consider the variational perspective. Although the reconstruction term $\log p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)$ in Eqn. (39) depends on the data, empirically it is nearly constant: $\log p_{\epsilon0}(\bm{x}_0|\bm{x}_\epsilon)\approx-0.01\times d$. This gives the approximate bound

$$\log P_0(\bm{x}_0)\geq\mathbb{E}_{\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})}\left[\log p_\epsilon(\bm{x}_\epsilon)\right]+\frac{d}{2}\left(1+\log(2\pi\sigma_\epsilon^2)\right)-0.01\times d \tag{42}$$

The only differences are that our truncated-normal dequantization evaluates the ODE likelihood at $\hat{\bm{x}}_\epsilon$ rather than $\bm{x}_\epsilon$, and that the constant terms differ slightly ($-0.01522\times d$ versus $-0.01\times d$).

Table 3: Likelihood results (NLL in bits/dim) under different bounds and numbers of importance samples $K$. $K=1$ means the importance-weighted estimator is not used.

| NLL (BPD) | Uniform K=1 | Uniform K=5 | Uniform K=20 | Variational K=1 | Variational K=5 | Variational K=20 | Truncated-Normal K=1 | Truncated-Normal K=5 | Truncated-Normal K=20 |
|---|---|---|---|---|---|---|---|---|---|
| CIFAR-10 (VP) | 2.74 | 2.72 | 2.71 | 2.60 | 2.59 | 2.58 | 2.60 | 2.58 | 2.57 |
| CIFAR-10 (SP) | 2.81 | 2.79 | 2.78 | 2.61 | 2.59 | 2.58 | 2.60 | 2.57 | 2.56 |
| ImageNet-32 (VP) | 3.52 | 3.51 | 3.50 | 3.46 | 3.44 | 3.44 | 3.45 | 3.44 | 3.43 |
| ImageNet-32 (SP) | 3.57 | 3.56 | 3.55 | 3.48 | 3.47 | 3.46 | 3.47 | 3.45 | 3.44 |
Remark A.1.

For high-dimensional data such as images, the log-likelihood scales with the data dimension $d$, so raw log-likelihoods are hard to compare across datasets. In practice, we instead compare bits per dimension (BPD):

$$\text{BPD}=\mathbb{E}_{\bm{x}_0\sim q_0}\left[\frac{-\log P_0(\bm{x}_0)}{d\log 2}\right] \tag{43}$$

where $q_{0}$ is the data distribution. Since BPD divides the negative log-likelihood by $d$ (and converts nats to bits), it measures the average cost per dimension, so the result does not scale with the dimensionality.
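Eq. (43) amounts to a one-line conversion from nats to bits per dimension. A minimal sketch, assuming the per-sample log-likelihoods are already available in nats (the function name is illustrative):

```python
import math


def bits_per_dim(log_likelihoods_nats, d):
    """Eq. (43): average of -log P0(x0) / (d * ln 2) over the given test samples."""
    n = len(log_likelihoods_nats)
    return sum(-lp / (d * math.log(2.0)) for lp in log_likelihoods_nats) / n
```

For a CIFAR-10 image, $d=32\times32\times3=3072$; a model that assigns probability $1/256$ to every pixel value has $\log P_0=-3072\log 256$ and therefore scores exactly 8 BPD.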

We evaluate the two types of dequantization and the variational perspective on our final models, using different numbers of importance samples $K$. The results are listed in Table 3. Empirically, truncated-normal dequantization performs slightly better than the variational bound, while uniform dequantization gives noticeably worse likelihood due to the large training-evaluation gap. We also observe that increasing $K$ further improves the results, since it yields a tighter bound.
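The importance-weighted estimator behind the $K>1$ columns of Table 3 averages $K$ importance weights inside the logarithm. The sketch below assumes that the $K$ log-weights $\log w_k=\log p(\bm{x}_0+\bm{u}_k)-\log q(\bm{u}_k|\bm{x}_0)$ (hypothetical inputs from $K$ dequantization draws $\bm{u}_k\sim q$) have already been computed, and only shows how to combine them in a numerically stable way; it is not the authors' evaluation code.

```python
import math


def logmeanexp(log_w):
    """Numerically stable log((1/K) * sum_k exp(log_w[k]))."""
    m = max(log_w)
    return m + math.log(sum(math.exp(lw - m) for lw in log_w)) - math.log(len(log_w))


def iw_bound(log_w):
    """Importance-weighted log-likelihood estimate from K log-weights.

    log_w[k] is assumed to be log p(x0 + u_k) - log q(u_k | x0) for an
    independent dequantization draw u_k ~ q(. | x0) (hypothetical inputs).
    """
    return logmeanexp(log_w)
```

By Jensen's inequality, the log of the average weight is never smaller than the average of the log-weights, which is why the estimate tightens as $K$ grows; with $K=1$ it reduces to the plain single-sample estimate.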

Figure 4: Likelihood evaluation results under uniform dequantization for different start times $\gamma_{\epsilon}$. To plot the curves, we estimate the likelihood using the first 1024 test samples for CIFAR-10 and the first 512 test samples for ImageNet-32.
Remark A.2.

Since uniformly dequantized data has a larger noise level than truncated-normal dequantized data, we find that evaluating $\log p_{\epsilon}(\bm{x}_{0}+\bm{u})$ at the start time $\gamma_{\epsilon}=-13.3$ leads to poor likelihood. We therefore tune $\gamma_{\epsilon}$ for uniform dequantization (Figure 4), and eventually choose $\gamma_{\epsilon}=-12.0,-11.9,-11.7,-11.6$ for CIFAR-10 (VP), CIFAR-10 (SP), ImageNet-32 (VP) and ImageNet-32 (SP), respectively.
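The argument of Remark A.2 can be checked with back-of-the-envelope numbers. The sketch below rests on assumptions that are not stated in this section: the convention $\gamma=\log(\sigma_\gamma^2/\alpha_\gamma^2)$ (consistent with Eq. (49) in Appendix B), the VP schedule $\sigma_\gamma^2=\mathrm{sigmoid}(\gamma)$, a straight-path SP schedule with $\alpha_\gamma+\sigma_\gamma=1$, and 8-bit data rescaled to $[-1,1]$ with uniform noise over a bin of assumed width $2/256$.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigma_vp(gamma):
    # Assumed VP convention: sigma^2 = sigmoid(gamma), alpha^2 = sigmoid(-gamma).
    return math.sqrt(sigmoid(gamma))


def sigma_sp(gamma):
    # Assumed straight-path convention: alpha + sigma = 1, sigma / alpha = exp(gamma / 2).
    return sigmoid(gamma / 2.0)


# Std of uniform dequantization noise for 8-bit data in [-1, 1]
# (assumed bin width 2/256; the std of U[0, w) is w / sqrt(12)).
UNIFORM_STD = (2.0 / 256.0) / math.sqrt(12.0)

for name, fn, g in [("VP", sigma_vp, -13.3), ("VP", sigma_vp, -12.0),
                    ("SP", sigma_sp, -11.9), ("VP", sigma_vp, -11.7),
                    ("SP", sigma_sp, -11.6)]:
    print(f"{name} gamma={g}: sigma={fn(g):.2e}, uniform std={UNIFORM_STD:.2e}")
```

Under these assumptions, $\sigma_\gamma$ at $\gamma_\epsilon=-13.3$ (about $1.3\times10^{-3}$) is below the uniform-noise standard deviation (about $2.3\times10^{-3}$), while all four tuned start times give $\sigma_\gamma$ above it, which is in line with the remark.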

Appendix B Equivalence of different predictors and matching objectives

The following theorem shows that the different predictors can be transformed into one another by a time-dependent skip connection, and that each of them can be trained in a simulation-free manner with an equivalent matching objective.

Theorem B.1.

Let $\bm{x}_{0}$ be a sample from the data distribution and $\bm{\epsilon}$ a sample from $\mathcal{N}(\bm{0},\bm{I})$. Denote $\bm{x}_{t}=\alpha_{t}\bm{x}_{0}+\sigma_{t}\bm{\epsilon}$ and $\bm{v}=\dot{\alpha}_{t}\bm{x}_{0}+\dot{\sigma}_{t}\bm{\epsilon}$. Suppose we have four kinds of predictors parameterized by $\theta$, with corresponding matching objectives weighted by a positive time weighting function $w(t)$:

  • score predictor $\bm{s}_{\theta}(\bm{x}_{t},t)$ and score matching loss $\mathcal{J}_{\text{SM}}(\theta,w(t))=\mathbb{E}_{t}\left[w(t)\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\bm{s}_{\theta}(\bm{x}_{t},t)-\nabla_{\bm{x}}\log q_{t}(\bm{x}_{t})\|_{2}^{2}]\right]$

  • noise predictor $\bm{\epsilon}_{\theta}(\bm{x}_{t},t)$ and noise matching loss $\mathcal{J}_{\text{NM}}(\theta,w(t))=\mathbb{E}_{t}\left[w(t)\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\bm{\epsilon}_{\theta}(\bm{x}_{t},t)-\bm{\epsilon}\|_{2}^{2}]\right]$

  • data predictor $\bm{x}_{\theta}(\bm{x}_{t},t)$ and data matching loss $\mathcal{J}_{\text{DM}}(\theta,w(t))=\mathbb{E}_{t}\left[w(t)\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\bm{x}_{\theta}(\bm{x}_{t},t)-\bm{x}_{0}\|_{2}^{2}]\right]$

  • velocity predictor $\bm{v}_{\theta}(\bm{x}_{t},t)$ and flow matching loss $\mathcal{J}_{\text{FM}}(\theta,w(t))=\mathbb{E}_{t}\left[w(t)\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\bm{v}_{\theta}(\bm{x}_{t},t)-\bm{v}\|_{2}^{2}]\right]$

For any $w(t)$, if we denote the optimal (ground-truth) predictors that minimize the corresponding matching losses by $\bm{s}^{*}(\bm{x}_{t},t)$, $\bm{\epsilon}^{*}(\bm{x}_{t},t)$, $\bm{x}^{*}(\bm{x}_{t},t)$ and $\bm{v}^{*}(\bm{x}_{t},t)$ respectively, then they are equivalent through the following relations:

$$\begin{aligned}
\bm{\epsilon}^{*}(\bm{x}_{t},t)&=-\sigma_{t}\bm{s}^{*}(\bm{x}_{t},t)\\
\bm{x}^{*}(\bm{x}_{t},t)&=\frac{1}{\alpha_{t}}\bm{x}_{t}+\frac{\sigma_{t}^{2}}{\alpha_{t}}\bm{s}^{*}(\bm{x}_{t},t)\\
\bm{v}^{*}(\bm{x}_{t},t)&=f(t)\bm{x}_{t}-\frac{1}{2}g^{2}(t)\bm{s}^{*}(\bm{x}_{t},t)
\end{aligned}\tag{44}$$

where $\bm{s}^{*}(\bm{x}_{t},t)=\nabla_{\bm{x}}\log q_{t}(\bm{x}_{t})$ is the ground-truth score.
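The relations in Eq. (44) can be sanity-checked on a case where everything is available in closed form. The sketch below assumes one-dimensional Gaussian data $x_0\sim\mathcal{N}(0,c^2)$, for which the score and the posterior means $\mathbb{E}[x_0|x_t]$ and $\mathbb{E}[\epsilon|x_t]$ are linear in $x_t$, a hypothetical cosine VP schedule, and the choices $f(t)=\dot{\alpha}_t/\alpha_t$ and $g^2(t)=\mathrm{d}\sigma_t^2/\mathrm{d}t-2f(t)\sigma_t^2$ (the only choices for which the last line of Eq. (44) matches $\mathbb{E}[\dot{\alpha}_t x_0+\dot{\sigma}_t\epsilon|x_t]$).

```python
import math


def schedule(t):
    """Hypothetical cosine VP schedule: alpha_t, sigma_t and their t-derivatives."""
    a, s = math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)
    return a, s, -0.5 * math.pi * s, 0.5 * math.pi * a


def predictors_from_score(x, t, score):
    """Eq. (44): map a score value to noise, data and velocity predictions."""
    a, s, da, ds = schedule(t)
    f = da / a                               # f(t) = d log alpha_t / dt
    g2 = 2.0 * s * ds - 2.0 * f * s ** 2     # g^2(t) = d sigma_t^2 / dt - 2 f(t) sigma_t^2
    eps = -s * score
    x0 = x / a + (s ** 2 / a) * score
    v = f * x - 0.5 * g2 * score
    return eps, x0, v


def gaussian_oracle(x, t, c=0.7):
    """For x0 ~ N(0, c^2) in 1-D: exact score and posterior means given x_t = x."""
    a, s, da, ds = schedule(t)
    var = a ** 2 * c ** 2 + s ** 2            # marginal variance of x_t
    score = -x / var
    mean_x0 = a * c ** 2 * x / var            # E[x0 | x_t]
    mean_eps = s * x / var                    # E[eps | x_t]
    mean_v = da * mean_x0 + ds * mean_eps     # E[v | x_t]
    return score, mean_x0, mean_eps, mean_v
```

Feeding the exact score into `predictors_from_score` reproduces the three posterior means returned by `gaussian_oracle`, for any $t\in(0,1)$.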

Proof.

For any positive weighting $w(t)$, the overall optimum of the matching loss $\mathbb{E}_{t}\left[w(t)\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\cdot\|_{2}^{2}]\right]$ is achieved when the inner expectation $\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\cdot\|_{2}^{2}]$ is minimized for every $t$. For fixed $t$, by denoising score matching (Vincent, 2011), minimizing $\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\bm{s}_{\theta}(\bm{x}_{t},t)-\nabla_{\bm{x}}\log q_{t}(\bm{x}_{t})\|_{2}^{2}]$ is equivalent to minimizing

$$\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}[\|\bm{s}_{\theta}(\bm{x}_{t},t)-\nabla_{\bm{x}}\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})\|_{2}^{2}]=\mathbb{E}_{q_{t}(\bm{x}_{t})}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\|\bm{s}_{\theta}(\bm{x}_{t},t)-\nabla_{\bm{x}}\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})\|_{2}^{2}],$$

where $\nabla_{\bm{x}}\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})=-\frac{\bm{\epsilon}}{\sigma_{t}}$ because $q_{0t}(\bm{x}_{t}|\bm{x}_{0})=\mathcal{N}(\alpha_{t}\bm{x}_{0},\sigma_{t}^{2}\bm{I})$. The inner expectation is a minimum mean square error problem, whose solution is the conditional mean, so the optimal score predictor satisfies

$$\bm{s}^{*}(\bm{x}_{t},t)=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\nabla_{\bm{x}}\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})]=-\frac{1}{\sigma_{t}}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{\epsilon}]\tag{45}$$

Similarly, for $\mathcal{J}_{\text{NM}}(\theta,w(t))$, the optimal noise predictor satisfies

$$\bm{\epsilon}^{*}(\bm{x}_{t},t)=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{\epsilon}]=-\sigma_{t}\bm{s}^{*}(\bm{x}_{t},t)\tag{46}$$
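The key step of the proof is that the minimizer of an MMSE regression is the conditional mean of its target. A small Monte Carlo illustration, under assumed values $\alpha_t=0.8$, $\sigma_t=0.6$ at a fixed $t$ and one-dimensional standard Gaussian data: a least-squares linear noise predictor fitted on samples of $(x_t,\epsilon)$ approaches $\mathbb{E}[\epsilon|x_t]=-\sigma_t s^*(x_t,t)$. A linear model suffices only because the data are Gaussian in this toy setting.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, sigma, c = 0.8, 0.6, 1.0           # hypothetical schedule values at a fixed t
n = 200_000

x0 = c * rng.standard_normal(n)           # 1-D Gaussian "data"
eps = rng.standard_normal(n)
xt = alpha * x0 + sigma * eps

# Least-squares fit of a linear noise predictor eps_theta(x_t) = w * x_t.
# For Gaussian data E[eps | x_t] is linear in x_t, so the MMSE solution
# within this family is the true posterior mean.
w_fit = np.dot(xt, eps) / np.dot(xt, xt)

var_t = alpha ** 2 * c ** 2 + sigma ** 2
w_true = sigma / var_t                    # E[eps | x_t] = sigma * x_t / var_t
score_coef = -1.0 / var_t                 # exact score: s*(x_t) = -x_t / var_t
print(w_fit, w_true, -sigma * score_coef) # the last two coincide, as in Eq. (46)
```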

For $\mathcal{J}_{\text{DM}}(\theta,w(t))$, the optimal data predictor satisfies

$$\begin{aligned}
\bm{x}^{*}(\bm{x}_{t},t)&=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{x}_{0}]\\
&=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[\frac{\bm{x}_{t}-\sigma_{t}\bm{\epsilon}}{\alpha_{t}}\right]\\
&=\frac{1}{\alpha_{t}}\bm{x}_{t}-\frac{\sigma_{t}}{\alpha_{t}}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{\epsilon}]\\
&=\frac{1}{\alpha_{t}}\bm{x}_{t}+\frac{\sigma_{t}^{2}}{\alpha_{t}}\bm{s}^{*}(\bm{x}_{t},t)
\end{aligned}\tag{47}$$

For $\mathcal{J}_{\text{FM}}(\theta,w(t))$, the optimal velocity predictor satisfies

$$\begin{aligned}
\bm{v}^{*}(\bm{x}_{t},t)&=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\dot{\alpha}_{t}\bm{x}_{0}+\dot{\sigma}_{t}\bm{\epsilon}]\\
&=\dot{\alpha}_{t}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{x}_{0}]+\dot{\sigma}_{t}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{\epsilon}]\\
&=\frac{\dot{\alpha}_{t}}{\alpha_{t}}\bm{x}_{t}+\left(\frac{\dot{\alpha}_{t}}{\alpha_{t}}\sigma_{t}^{2}-\sigma_{t}\dot{\sigma}_{t}\right)\bm{s}^{*}(\bm{x}_{t},t)\\
&=f(t)\bm{x}_{t}-\frac{1}{2}g^{2}(t)\bm{s}^{*}(\bm{x}_{t},t)
\end{aligned}\tag{48}$$

where the last step uses $f(t)=\frac{\dot{\alpha}_{t}}{\alpha_{t}}$ and $g^{2}(t)=\frac{\mathrm{d}\sigma_{t}^{2}}{\mathrm{d}t}-2f(t)\sigma_{t}^{2}$.

The equivalence of the optimal predictors also implies the equivalence of the parameterized predictors: each can be obtained from another by the same time-dependent transformation. From the above theorem, $\bm{v}_{\theta}(\bm{x}_{t},t)$ and $\bm{\epsilon}_{\theta}(\bm{x}_{t},t)$ are related by $\bm{v}_{\theta}(\bm{x}_{t},t)=f(t)\bm{x}_{t}+\frac{g^{2}(t)}{2\sigma_{t}}\bm{\epsilon}_{\theta}(\bm{x}_{t},t)$. In practice, we use $\gamma$ timing. From the relationships $\bm{v}_{\theta}(\bm{x}_{t},t)=\bm{v}_{\theta}(\bm{x}_{\gamma},\gamma)\frac{\mathrm{d}\gamma}{\mathrm{d}t}$ and $\bm{\epsilon}_{\theta}(\bm{x}_{t},t)=\bm{\epsilon}_{\theta}(\bm{x}_{\gamma},\gamma)$, we obtain the noise predictor expressed in terms of $\bm{v}_{\theta}(\bm{x}_{\gamma},\gamma)$:

$$\bm{\epsilon}_{\theta}(\bm{x}_{\gamma},\gamma)=2\,\frac{\bm{v}_{\theta}(\bm{x}_{\gamma},\gamma)-\frac{\dot{\alpha}_{\gamma}}{\alpha_{\gamma}}\bm{x}_{\gamma}}{\sigma_{\gamma}}\tag{49}$$

where dots now denote derivatives with respect to $\gamma$.

Further, we can replace $\bm{v}_{\theta}(\bm{x}_{\gamma},\gamma)$ with the normalized velocity predictor $\tilde{\bm{v}}_{\theta}(\bm{x}_{\gamma},\gamma)=\bm{v}_{\theta}(\bm{x}_{\gamma},\gamma)/\sqrt{\dot{\alpha}_{\gamma}^{2}+\dot{\sigma}_{\gamma}^{2}}$.
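A minimal sketch of Eq. (49) and of the normalized velocity. It assumes the convention $\gamma=\log(\sigma_\gamma^2/\alpha_\gamma^2)$ with dots denoting $\mathrm{d}/\mathrm{d}\gamma$; under this convention $\dot{\sigma}_\gamma/\sigma_\gamma-\dot{\alpha}_\gamma/\alpha_\gamma=1/2$, and substituting $\bm{x}_\gamma=\alpha_\gamma\bm{x}_0+\sigma_\gamma\bm{\epsilon}$ and $\bm{v}=\dot{\alpha}_\gamma\bm{x}_0+\dot{\sigma}_\gamma\bm{\epsilon}$ into Eq. (49) returns $2\bm{\epsilon}(\dot{\sigma}_\gamma/\sigma_\gamma-\dot{\alpha}_\gamma/\alpha_\gamma)=\bm{\epsilon}$. The VP and SP formulas below are our assumptions, not taken from this section.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def vp(gamma):
    """Assumed VP schedule: sigma^2 = sigmoid(gamma), alpha^2 = sigmoid(-gamma).
    Returns alpha, sigma and their derivatives with respect to gamma."""
    a, s = math.sqrt(sigmoid(-gamma)), math.sqrt(sigmoid(gamma))
    return a, s, -0.5 * a * s ** 2, 0.5 * s * a ** 2


def sp(gamma):
    """Assumed straight-path schedule: alpha + sigma = 1, sigma / alpha = exp(gamma / 2)."""
    s = sigmoid(0.5 * gamma)
    a = 1.0 - s
    return a, s, -0.5 * a * s, 0.5 * a * s


def eps_from_v(v, x, gamma, sched):
    """Eq. (49): noise prediction from a velocity prediction in gamma timing."""
    a, s, da, _ = sched(gamma)
    return 2.0 * (v - (da / a) * x) / s


def normalize_v(v, gamma, sched):
    """Normalized velocity v / sqrt(alpha_dot^2 + sigma_dot^2)."""
    _, _, da, ds = sched(gamma)
    return v / math.sqrt(da ** 2 + ds ** 2)
```

Plugging the exact velocity target into `eps_from_v` recovers the noise that generated $\bm{x}_\gamma$, for both schedules.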

Moreover, by applying the relations above, we can freely derive equivalent training objectives under different parameterizations. For example, replacing the normalized velocity predictor $\tilde{\bm{v}}_{\theta}$ with the score predictor $\bm{s}_{\theta}$ in the second-order objective Eqn. (4.3) yields a second-order denoising score matching objective similar to that of Lu et al. (2022a). However, although these objectives are theoretically equivalent, their actual performance depends heavily on the specific model architecture, hyperparameters and parameterization: Lu et al. (2022a) find that their high-order denoising score matching objectives only work for the VE schedule, and degrade the performance of pretrained models with the VP schedule.
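As one instance of deriving equivalent objectives: writing $\bm{x}_0=(\bm{x}_t-\sigma_t\bm{\epsilon})/\alpha_t$ gives $\bm{v}=f(t)\bm{x}_t+\frac{g^2(t)}{2\sigma_t}\bm{\epsilon}$, so with $\bm{v}_\theta=f(t)\bm{x}_t+\frac{g^2(t)}{2\sigma_t}\bm{\epsilon}_\theta$ we get $\bm{v}_\theta-\bm{v}=\frac{g^2(t)}{2\sigma_t}(\bm{\epsilon}_\theta-\bm{\epsilon})$, i.e. $\mathcal{J}_{\text{FM}}(\theta,w(t))=\mathcal{J}_{\text{NM}}\big(\theta,w(t)\,(g^2(t)/2\sigma_t)^2\big)$. The check below uses a hypothetical straight-path schedule $\alpha_t=1-t$, $\sigma_t=t$ and a random stand-in for the network output; it illustrates the identity and is not the authors' training code.

```python
import numpy as np

rng = np.random.default_rng(1)


def sp_schedule(t):
    """Hypothetical straight-path schedule alpha_t = 1 - t, sigma_t = t, t in (0, 1)."""
    return 1.0 - t, t, -1.0, 1.0


def per_sample_losses(t, x0, eps, eps_pred):
    a, s, da, ds = sp_schedule(t)
    f = da / a
    g2 = 2.0 * s * ds - 2.0 * f * s ** 2
    xt = a * x0 + s * eps
    v_target = da * x0 + ds * eps
    v_pred = f * xt + g2 / (2.0 * s) * eps_pred   # velocity implied by eps_pred
    fm = np.sum((v_pred - v_target) ** 2, axis=-1)
    nm = np.sum((eps_pred - eps) ** 2, axis=-1)
    return fm, nm, (g2 / (2.0 * s)) ** 2


t = 0.3
x0 = rng.standard_normal((4, 8))
eps = rng.standard_normal((4, 8))
eps_pred = eps + 0.1 * rng.standard_normal((4, 8))  # stand-in for a network output
fm, nm, scale = per_sample_losses(t, x0, eps, eps_pred)
print(np.allclose(fm, scale * nm))                  # True
```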

Appendix C Specifications under VP and SP schedules

As stated in Section 4.3, using $\gamma$ timing and the normalized velocity predictor $\tilde{\bm{v}}_{\theta}$, the likelihood-weighted first-order and second-order flow matching objectives are reformulated as:

$$\mathcal{J}_{\text{FM}}=\int_{\gamma_{0}}^{\gamma_{T}}2\frac{\dot{\alpha}_{\gamma}^{2}+\dot{\sigma}_{\gamma}^{2}}{\sigma_{\gamma}^{2}}\mathbb{E}_{\bm{x}_{0},\bm{\epsilon}}\|\tilde{\bm{v}}_{\theta}(\bm{x}_{\gamma},\gamma)-\tilde{\bm{v}}\|_{2}^{2}\,\mathrm{d}\gamma\tag{50}$$
$$\mathcal{J}_{\text{FM},\mathrm{tr}}=\int_{\gamma_0}^{\gamma_T}2\frac{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}{\sigma_\gamma^2}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left(\sigma_\gamma\mathrm{tr}(\nabla\tilde{\bm{v}}_\theta)-\frac{\dot{\sigma}_\gamma}{\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}}d+\frac{2\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}}{\sigma_\gamma}\lVert\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\rVert_2^2\right)^2\mathrm{d}\gamma \tag{51}$$

where $\bm{v}=\dot{\alpha}_\gamma\bm{x}_0+\dot{\sigma}_\gamma\bm{\epsilon}$ and $\tilde{\bm{v}}=\bm{v}/\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}$. For the VP and SP schedules, since $\gamma=\log(\sigma_\gamma^2/\alpha_\gamma^2)$, their schedule properties ($\alpha_\gamma^2+\sigma_\gamma^2=1$ for VP and $\alpha_\gamma+\sigma_\gamma=1$ for SP) make $\alpha_\gamma$ and $\sigma_\gamma$ deterministic functions of $\gamma$ without any hyperparameters. Thus, we can derive their specific objectives and equivalent predictors from the formulas for general noise schedules. We summarize them in Table 4, where $\hat{\tilde{\bm{v}}}_\theta$ denotes the stop-gradient version of $\tilde{\bm{v}}_\theta$.
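To see how the general formulas specialize, the following Python sketch evaluates the three schedule-dependent coefficients of Eqns. (50) and (51) (the outer weight $2(\dot\alpha_\gamma^2+\dot\sigma_\gamma^2)/\sigma_\gamma^2$, the coefficient of $d$, and the coefficient of the squared error) for any schedule supplied as functions of $\gamma$, approximating $\dot\alpha_\gamma,\dot\sigma_\gamma$ by central finite differences. The schedule functions follow Table 4; the step size `h` and all function names are illustrative assumptions, not part of the original implementation.

```python
import math


# VP and SP schedules in gamma-timing, as listed in Table 4.
def vp_alpha(g):
    return math.sqrt(1.0 / (1.0 + math.exp(g)))


def vp_sigma(g):
    return math.sqrt(1.0 / (1.0 + math.exp(-g)))


def sp_alpha(g):
    return 1.0 / (1.0 + math.exp(g / 2))


def sp_sigma(g):
    return 1.0 / (1.0 + math.exp(-g / 2))


def derivative(f, g, h=1e-5):
    """Central finite-difference approximation of df/dgamma at g."""
    return (f(g + h) - f(g - h)) / (2 * h)


def fm_coefficients(alpha, sigma, g):
    """Schedule-dependent coefficients of Eqns. (50)-(51) at gamma = g.

    Returns (weight, coef_d, coef_mse) where
      weight   = 2 (alpha_dot^2 + sigma_dot^2) / sigma^2
      coef_d   = sigma_dot / sqrt(alpha_dot^2 + sigma_dot^2)
      coef_mse = 2 sqrt(alpha_dot^2 + sigma_dot^2) / sigma
    """
    a_dot = derivative(alpha, g)
    s_dot = derivative(sigma, g)
    speed = math.sqrt(a_dot ** 2 + s_dot ** 2)
    s = sigma(g)
    return 2 * speed ** 2 / s ** 2, s_dot / speed, 2 * speed / s


# Compare the numerical coefficients with the closed forms implied by Table 4.
for g in (-2.0, 0.0, 3.0):
    a = vp_alpha(g)
    print("VP", fm_coefficients(vp_alpha, vp_sigma, g), (a * a / 2, a, a))
    a = sp_alpha(g)
    print("SP", fm_coefficients(sp_alpha, sp_sigma, g), (a * a, 1 / math.sqrt(2), math.sqrt(2) * a))
```

For VP the three coefficients reduce to $(\alpha_\gamma^2/2,\ \alpha_\gamma,\ \alpha_\gamma)$ and for SP to $(\alpha_\gamma^2,\ 1/\sqrt{2},\ \sqrt{2}\alpha_\gamma)$, which are the factors appearing in the objective rows of Table 4.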

Table 4: Specification of related values and objectives under VP and SP schedules.

| Quantity | VP | SP |
|---|---|---|
| $\alpha_\gamma$ | $\sqrt{\frac{1}{1+\exp(\gamma)}}$ | $\frac{1}{1+\exp(\gamma/2)}$ |
| $\sigma_\gamma$ | $\sqrt{\frac{1}{1+\exp(-\gamma)}}$ | $\frac{1}{1+\exp(-\gamma/2)}$ |
| $\dot{\alpha}_\gamma$ | $-\frac{1}{2}\alpha_\gamma\sigma_\gamma^2$ | $-\frac{1}{2}\alpha_\gamma\sigma_\gamma$ |
| $\dot{\sigma}_\gamma$ | $\frac{1}{2}\alpha_\gamma^2\sigma_\gamma$ | $\frac{1}{2}\alpha_\gamma\sigma_\gamma$ |
| $\sqrt{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}$ | $\frac{1}{2}\alpha_\gamma\sigma_\gamma$ | $\frac{1}{\sqrt{2}}\alpha_\gamma\sigma_\gamma$ |
| $\tilde{\bm{v}}$ | $\alpha_\gamma\bm{\epsilon}-\sigma_\gamma\bm{x}_0$ | $\frac{\bm{\epsilon}-\bm{x}_0}{\sqrt{2}}$ |
| $\mathcal{J}_{\text{FM}}$ | $\frac{1}{2}\int_{\gamma_0}^{\gamma_T}\alpha_\gamma^2\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\lVert\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\rVert_2^2\,\mathrm{d}\gamma$ | $\int_{\gamma_0}^{\gamma_T}\alpha_\gamma^2\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\lVert\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)-\tilde{\bm{v}}\rVert_2^2\,\mathrm{d}\gamma$ |
| $\mathcal{J}_{\text{FM},\mathrm{tr}}$ | $\frac{1}{2}\int_{\gamma_0}^{\gamma_T}\alpha_\gamma^2\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left(\sigma_\gamma\mathrm{tr}(\nabla\tilde{\bm{v}}_\theta)-\alpha_\gamma d+\alpha_\gamma\lVert\hat{\tilde{\bm{v}}}_\theta-\tilde{\bm{v}}\rVert_2^2\right)^2\mathrm{d}\gamma$ | $\int_{\gamma_0}^{\gamma_T}\alpha_\gamma^2\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left(\sigma_\gamma\mathrm{tr}(\nabla\tilde{\bm{v}}_\theta)-\frac{1}{\sqrt{2}}d+\sqrt{2}\alpha_\gamma\lVert\hat{\tilde{\bm{v}}}_\theta-\tilde{\bm{v}}\rVert_2^2\right)^2\mathrm{d}\gamma$ |
| $\bm{\epsilon}_\theta(\bm{x}_\gamma,\gamma)$ | $\sigma_\gamma\bm{x}_\gamma+\alpha_\gamma\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)$ | $\bm{x}_\gamma+\sqrt{2}\alpha_\gamma\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)$ |
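The closed-form entries of Table 4 can be checked numerically. The Python sketch below builds $\bm x_\gamma$ and the normalized velocity target $\tilde{\bm v}$ for both schedules, then recovers $\bm\epsilon$ from a velocity prediction via the $\bm\epsilon_\theta$ row of the table; when the prediction equals the target, the recovered noise matches the true noise up to rounding. Vectors are plain Python lists, and the helper names and sample values are illustrative assumptions.

```python
import math


def alpha_sigma(g, schedule):
    """(alpha_gamma, sigma_gamma) for the VP or SP schedule in Table 4."""
    if schedule == "vp":
        return math.sqrt(1 / (1 + math.exp(g))), math.sqrt(1 / (1 + math.exp(-g)))
    if schedule == "sp":
        return 1 / (1 + math.exp(g / 2)), 1 / (1 + math.exp(-g / 2))
    raise ValueError(f"unknown schedule: {schedule}")


def noisy_sample(x0, eps, g, schedule):
    """x_gamma = alpha_gamma * x0 + sigma_gamma * eps."""
    a, s = alpha_sigma(g, schedule)
    return [a * u + s * e for u, e in zip(x0, eps)]


def v_tilde_target(x0, eps, g, schedule):
    """Normalized velocity target (the v-tilde row of Table 4)."""
    a, s = alpha_sigma(g, schedule)
    if schedule == "vp":
        return [a * e - s * u for u, e in zip(x0, eps)]
    return [(e - u) / math.sqrt(2) for u, e in zip(x0, eps)]


def eps_from_v_tilde(x_g, v_tilde, g, schedule):
    """Equivalent noise prediction (the epsilon_theta row of Table 4)."""
    a, s = alpha_sigma(g, schedule)
    if schedule == "vp":
        return [s * x + a * v for x, v in zip(x_g, v_tilde)]
    return [x + math.sqrt(2) * a * v for x, v in zip(x_g, v_tilde)]


x0, eps, g = [0.3, -1.1, 0.8], [1.5, 0.2, -0.7], 0.4
for schedule in ("vp", "sp"):
    x_g = noisy_sample(x0, eps, g, schedule)
    v = v_tilde_target(x0, eps, g, schedule)
    print(schedule, eps_from_v_tilde(x_g, v, g, schedule))  # ≈ eps
```

The recovery works because $\alpha_\gamma^2+\sigma_\gamma^2=1$ for VP and $\alpha_\gamma+\sigma_\gamma=1$ for SP.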

Next, we derive the designed importance sampling (IS) procedure. We choose a proposal distribution $p(\gamma)\propto\frac{\dot{\alpha}_\gamma^2+\dot{\sigma}_\gamma^2}{\sigma_\gamma^2}$, which is proportional to $\alpha_\gamma^2$ for both VP and SP. Since the density has an explicit expression, we use inverse transform sampling to design the sampling procedure. Concretely, we draw $t\in[0,1]$ uniformly and solve the following equation for $\gamma_t$:

$$\frac{1}{Z}\int_{\gamma_0}^{\gamma_t}\alpha_\gamma^2\,\mathrm{d}\gamma=t,\quad Z=\int_{\gamma_0}^{\gamma_1}\alpha_\gamma^2\,\mathrm{d}\gamma \tag{52}$$

Here we assume the maximum time is $T=1$, and $Z$ is a normalizing constant.

VP

We have (omitting the constant of integration)

$$\int\alpha_\gamma^2\,\mathrm{d}\gamma=\log\frac{1}{1+\exp(-\gamma)}=\log\sigma_\gamma^2 \tag{53}$$

Then the equation for inverse transform sampling is

$$\log\frac{1}{1+\exp(-\gamma_t)}-\log\sigma_{\gamma_0}^2=Zt,\quad Z=\log\frac{\sigma_{\gamma_1}^2}{\sigma_{\gamma_0}^2} \tag{54}$$

The solution has a closed-form expression, which gives the inverse transformation from $t$ to $\gamma$:

$$\gamma_t=\log\frac{1}{\exp(-Zt)/\sigma_{\gamma_0}^2-1},\quad t\sim\mathcal{U}(0,1) \tag{55}$$
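A minimal Python sketch of this closed-form sampler, assuming hypothetical endpoints $\gamma_0=-10$ and $\gamma_1=10$ (the actual endpoints depend on the model). Because $p(\gamma)=\alpha_\gamma^2/Z$, an integral $\int_{\gamma_0}^{\gamma_1}\alpha_\gamma^2 f(\gamma)\,\mathrm{d}\gamma$, such as the VP objective of Table 4 up to its factor $\frac{1}{2}$, can be estimated without bias as $Z$ times the average of $f$ at sampled $\gamma$ values. The function `f` stands in for the per-$\gamma$ expected loss and, like the other names here, is an assumption of this example.

```python
import math
import random


def vp_sigma2(g):
    """sigma_gamma^2 = 1 / (1 + exp(-gamma)) for the VP schedule."""
    return 1.0 / (1.0 + math.exp(-g))


def vp_normalizer(g0, g1):
    """Z = log(sigma_{gamma_1}^2 / sigma_{gamma_0}^2), Eqn. (54)."""
    return math.log(vp_sigma2(g1) / vp_sigma2(g0))


def vp_gamma_from_t(t, g0, g1):
    """Inverse transform of Eqn. (55): maps t in [0, 1] to gamma_t in [g0, g1]."""
    z = vp_normalizer(g0, g1)
    return math.log(1.0 / (math.exp(-z * t) / vp_sigma2(g0) - 1.0))


def vp_importance_estimate(f, g0, g1, n, rng):
    """Unbiased Monte Carlo estimate of the integral of alpha_gamma^2 * f(gamma) over [g0, g1]."""
    z = vp_normalizer(g0, g1)
    total = sum(f(vp_gamma_from_t(rng.random(), g0, g1)) for _ in range(n))
    return z * total / n


G0, G1 = -10.0, 10.0  # hypothetical gamma endpoints
rng = random.Random(0)
# For VP, alpha^2 * sigma^2 = d(sigma^2)/dgamma, so the exact integral of
# alpha^2 * sigma^2 is sigma^2(G1) - sigma^2(G0); the two printed values should be close.
print(vp_importance_estimate(vp_sigma2, G0, G1, 50_000, rng), vp_sigma2(G1) - vp_sigma2(G0))
```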

SP

We have (omitting the constant of integration)

$$\int\alpha_\gamma^2\,\mathrm{d}\gamma=-2\left(\log(1+\exp(-\gamma/2))+\frac{1}{1+\exp(-\gamma/2)}\right) \tag{56}$$

Denote $F(\gamma)=-\log(1+\exp(-\gamma/2))-(1+\exp(-\gamma/2))^{-1}$; then the equation for inverse transform sampling is

$$\frac{F(\gamma_t)-F(\gamma_0)}{F(\gamma_1)-F(\gamma_0)}=t \tag{57}$$

The solution has no closed-form expression. Similar to the implementation in Song et al. (2021b), we use the bisection method to find the root.
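The bisection step can be sketched in Python as follows. This is a scalar sketch, not the implementation of Song et al. (2021b); it relies on $F$ being strictly increasing (its derivative is $\alpha_\gamma^2/2>0$, half the integrand of Eqn. (56)), so the left-hand side of Eqn. (57) is monotone in $\gamma_t$. The endpoints, tolerance, and iteration cap are illustrative assumptions.

```python
import math


def sp_F(g):
    """F(gamma) = -log(1 + exp(-gamma/2)) - 1 / (1 + exp(-gamma/2))."""
    u = 1.0 + math.exp(-g / 2)
    return -math.log(u) - 1.0 / u


def sp_gamma_from_t(t, g0, g1, tol=1e-10, max_iter=200):
    """Solve (F(gamma_t) - F(g0)) / (F(g1) - F(g0)) = t for gamma_t by bisection."""
    f0, f1 = sp_F(g0), sp_F(g1)
    target = f0 + t * (f1 - f0)
    lo, hi = g0, g1
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if sp_F(mid) < target:  # F is increasing, so the root lies above mid
            lo = mid
        else:
            hi = mid
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


G0, G1 = -10.0, 10.0  # hypothetical gamma endpoints
print([round(sp_gamma_from_t(t, G0, G1), 4) for t in (0.0, 0.25, 0.5, 0.75, 1.0)])
```

Each call halves the bracket until it is narrower than `tol`, so about $\log_2((\gamma_1-\gamma_0)/\texttt{tol})$ evaluations of $F$ are needed per sample.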

Appendix D Illustration of velocity prediction and imbalance problem

(a) Illustration of velocity prediction, with $\bm{x}_0\sim q_0$, $\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})$, $\bm{x}_t=\alpha_t\bm{x}_0+\sigma_t\bm{\epsilon}$ and $\bm{v}=\partial_t\bm{x}_t(\bm{x}_0,\bm{\epsilon})=\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon}$. Left ellipse: $\bm{x}_0$ sampled from the data distribution. Right ellipse: $\bm{\epsilon}$ sampled from the standard Gaussian distribution. By independently drawing a pair $(\bm{x}_0,\bm{\epsilon})$, we can construct a diffusion path using the noise schedule.
(b) Mean squared loss at different times $t$. We plot $\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\lVert\bm{\epsilon}_\theta(\bm{x}_t,t)-\bm{\epsilon}\rVert_2^2\right]$ and $\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\lVert\tilde{\bm{v}}_\theta(\bm{x}_t,t)-\tilde{\bm{v}}\rVert_2^2\right]$ for noise and velocity prediction on our pretrained model, evaluated on 32 data samples $\bm{x}_0$ and 20 noise samples $\bm{\epsilon}$.
Figure 5: Illustration of velocity prediction and imbalance problem.

First, we give an intuitive illustration of the velocity parameterization and the corresponding flow matching objective in Section 4.1. As shown in Figure 5(a), for each pair $(\bm{x}_0,\bm{\epsilon})$ with $\bm{x}_0\sim q_0(\bm{x}_0)$ and $\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})$, let $\bm{x}_t=\alpha_t\bm{x}_0+\sigma_t\bm{\epsilon}$. As $t$ increases, $\bm{x}_t$ moves gradually from $\bm{x}_0$ to $\bm{\epsilon}$, forming a diffusion path in the sample space, and $\bm{v}$ is the velocity $\frac{\partial\bm{x}_t(\bm{x}_0,\bm{\epsilon})}{\partial t}$ along the path.
Thus, minimizing $\mathcal{J}_{\text{FM}}$ amounts to predicting, at each $\bm{x}_t$, the expected velocity over all possible $(\bm{x}_0,\bm{\epsilon})$ pairs that could have produced it.
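To make this concrete, the following Python sketch uses the VP schedule in $\gamma$-timing (Table 4) as the time variable, an illustrative assumption since Figure 5 is drawn in generic time $t$. It checks that $\bm v=\dot\alpha\bm x_0+\dot\sigma\bm\epsilon$ is the derivative of the path, and computes the regression target $\mathbb E[\bm v\mid\bm x_\gamma]$ for a hypothetical scalar dataset of equally likely points: for fixed $\bm x_\gamma$, each candidate $\bm x_0$ determines $\bm\epsilon=(\bm x_\gamma-\alpha_\gamma\bm x_0)/\sigma_\gamma$ and hence one velocity, and the posterior over $\bm x_0$ is proportional to $\mathcal N(\bm x_\gamma;\alpha_\gamma\bm x_0,\sigma_\gamma^2)$.

```python
import math


def vp_schedule(g):
    """VP schedule in gamma-timing (Table 4): alpha, sigma and their gamma-derivatives."""
    a = math.sqrt(1.0 / (1.0 + math.exp(g)))
    s = math.sqrt(1.0 / (1.0 + math.exp(-g)))
    return a, s, -0.5 * a * s * s, 0.5 * a * a * s


def path_point(x0, eps, g):
    """x_gamma = alpha_gamma * x0 + sigma_gamma * eps (scalar case)."""
    a, s, _, _ = vp_schedule(g)
    return a * x0 + s * eps


def path_velocity(x0, eps, g):
    """v = alpha_dot * x0 + sigma_dot * eps, the derivative of the path w.r.t. gamma."""
    _, _, a_dot, s_dot = vp_schedule(g)
    return a_dot * x0 + s_dot * eps


def conditional_velocity(x_g, g, data_points):
    """E[v | x_gamma] when x0 is uniform over the given scalar data points."""
    a, s, a_dot, s_dot = vp_schedule(g)
    # Posterior weights over x0, proportional to N(x_gamma; a * x0, s^2).
    log_w = [-((x_g - a * x0) ** 2) / (2 * s * s) for x0 in data_points]
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]  # subtract the max for numerical stability
    # Each x0 fixes eps = (x_gamma - a * x0) / s and hence one velocity.
    vels = [a_dot * x0 + s_dot * (x_g - a * x0) / s for x0 in data_points]
    return sum(wi * vi for wi, vi in zip(w, vels)) / sum(w)


g, x0, eps, h = 0.5, 0.7, -1.2, 1e-5
finite_diff = (path_point(x0, eps, g + h) - path_point(x0, eps, g - h)) / (2 * h)
print(finite_diff, path_velocity(x0, eps, g))  # the two values agree closely
print(conditional_velocity(0.3, g, [-1.0, 1.0]))  # regression target at x_gamma = 0.3
```

The conditional mean is the minimizer of the squared error over predictions that depend only on $\bm x_\gamma$, which is why the learned $\bm v_\theta$ targets it.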

Next, we interpret the advantage of velocity prediction from the perspective of balanced prediction difficulty. Intuitively, the noise prediction model suffers from an imbalance problem: at small $t$, $\bm{x}_t$ is close to the data, so extracting the small noise component is hard; at large $t$, $\bm{x}_t$ is close to pure noise, so noise prediction is easy and has a small error. Velocity prediction, by contrast, has a prediction target $\bm{v}$ that is less correlated with the input $\bm{x}_t$. We confirm this empirically on our pretrained model in Fig. 5(b), which plots the mean squared prediction error (MSE) w.r.t. time $t$: velocity prediction alleviates the imbalance problem by enlarging the training error at large $t$. Since the overall error is a weighted combination of the MSEs at different $t$ and is invariant to the parameterization, we conclude that under noise prediction the MSE near $t=1$ is lower but receives a larger weight, which leads to a larger gradient variance.
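A toy check of the correlation argument, assuming standardized Gaussian "data" $\bm x_0\sim\mathcal N(0,1)$ (a hypothetical stand-in for real images) and the VP schedule in $\gamma$-timing from Table 4. Under these assumptions the noise target $\bm\epsilon$ has correlation $\sigma_\gamma$ with the input $\bm x_\gamma$, ranging from nearly 0 for data-like inputs to nearly 1 for noise-like inputs, whereas $\tilde{\bm v}=\alpha_\gamma\bm\epsilon-\sigma_\gamma\bm x_0$ is uncorrelated with $\bm x_\gamma$ at every $\gamma$. This sketch only illustrates the intuition; it does not reproduce Figure 5(b).

```python
import math
import random


def vp_alpha_sigma(g):
    """(alpha_gamma, sigma_gamma) for the VP schedule in Table 4."""
    return math.sqrt(1.0 / (1.0 + math.exp(g))), math.sqrt(1.0 / (1.0 + math.exp(-g)))


def correlation(xs, ys):
    """Sample Pearson correlation of two equal-length lists."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / n
    vx = sum((x - mx) ** 2 for x in xs) / n
    vy = sum((y - my) ** 2 for y in ys) / n
    return cov / math.sqrt(vx * vy)


def target_input_correlations(g, n=20_000, seed=0):
    """Correlations of the noise target and the v-tilde target with x_gamma."""
    rng = random.Random(seed)
    a, s = vp_alpha_sigma(g)
    x0 = [rng.gauss(0.0, 1.0) for _ in range(n)]  # toy unit-variance "data"
    eps = [rng.gauss(0.0, 1.0) for _ in range(n)]
    x_g = [a * u + s * e for u, e in zip(x0, eps)]
    v_tilde = [a * e - s * u for u, e in zip(x0, eps)]
    return correlation(eps, x_g), correlation(v_tilde, x_g)


for g in (-6.0, 0.0, 6.0):
    print(g, vp_alpha_sigma(g)[1], target_input_correlations(g))
```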

Appendix E Relationship between velocity parameterization and other works

In this section, we demonstrate how the techniques in related works (Karras et al., 2022; Lipman et al., 2022; Salimans & Ho, 2022; Ho et al., 2022) can be reformulated as velocity parameterization.

E.1 Interpretation by preconditioning

Works that aim at improving the sample quality of diffusion models also consider network parameterizations that adaptively mix signal and noise. Karras et al. (2022) propose to precondition the neural network with a time-dependent skip connection that allows it to estimate either the data $\bm{x}_0$, the noise $\bm{\epsilon}$, or something in between. Similarly, we write the noise predictor $\bm{\epsilon}_\theta(\cdot)$ in the following form:

$$\bm{\epsilon}_\theta(\bm{x}_\gamma,\gamma)=c_{\text{skip}}(\gamma)\bm{x}_\gamma+c_{\text{out}}(\gamma)F_\theta(c_{\text{in}}(\gamma)\bm{x}_\gamma,\gamma) \tag{58}$$

where $F_\theta(\cdot)$ is the pure network and $\bm{x}_\gamma=\alpha_\gamma\bm{x}_0+\sigma_\gamma\bm{\epsilon}$. The flow matching loss can be rewritten as

$$\begin{aligned}
\mathcal{J}_{\text{FM}}(\theta)&=\frac{1}{2}\int_{\gamma_0}^{\gamma_T}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\lVert\bm{\epsilon}_\theta(\bm{x}_\gamma,\gamma)-\bm{\epsilon}\rVert_2^2\right]\mathrm{d}\gamma\\
&=\frac{1}{2}\int_{\gamma_0}^{\gamma_T}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[\lVert c_{\text{skip}}(\gamma)\bm{x}_\gamma+c_{\text{out}}(\gamma)F_\theta(c_{\text{in}}(\gamma)\bm{x}_\gamma,\gamma)-\bm{\epsilon}\rVert_2^2\right]\mathrm{d}\gamma\\
&=\frac{1}{2}\int_{\gamma_0}^{\gamma_T}\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left[c_{\text{out}}(\gamma)^2\lVert F_\theta(c_{\text{in}}(\gamma)\bm{x}_\gamma,\gamma)-F_{\text{target}}(\bm{x}_0,\bm{\epsilon},\gamma)\rVert_2^2\right]\mathrm{d}\gamma
\end{aligned} \tag{59}$$

where

$$
F_{\text{target}}(\bm{x}_0,\bm{\epsilon},\gamma)=\frac{\bm{\epsilon}-c_{\text{skip}}(\gamma)\bm{x}_\gamma}{c_{\text{out}}(\gamma)}
\tag{60}
$$

Following the first principles of EDM (Karras et al., 2022), we derive the formulas for $c_{\text{in}}(\gamma)$, $c_{\text{out}}(\gamma)$ and $c_{\text{skip}}(\gamma)$ from the following requirements:

  1. The training inputs of $F_\theta(\cdot)$ have unit variance.
  2. The effective training target $F_{\text{target}}$ has unit variance.
  3. We select $c_{\text{skip}}(\gamma)$ to minimize $c_{\text{out}}(\gamma)$, so that the errors of $F_\theta$ are amplified as little as possible.

From principle 1, we have

$$
\begin{aligned}
1 &= \mathrm{Var}\left[c_{\text{in}}(\gamma)\bm{x}_\gamma\right]\\
1 &= \mathrm{Var}\left[c_{\text{in}}(\gamma)(\alpha_\gamma\bm{x}_0+\sigma_\gamma\bm{\epsilon})\right]\\
1 &= c_{\text{in}}^2(\gamma)(\alpha_\gamma^2\sigma_{\text{data}}^2+\sigma_\gamma^2)\\
c_{\text{in}}(\gamma) &= \frac{1}{\sqrt{\sigma_\gamma^2+\sigma_{\text{data}}^2\alpha_\gamma^2}}
\end{aligned}
\tag{61}
$$

From principle 2, we have

$$
\begin{aligned}
1 &= \mathrm{Var}\left[F_{\text{target}}(\bm{x}_0,\bm{\epsilon},\gamma)\right]\\
1 &= \mathrm{Var}\left[\frac{\bm{\epsilon}-c_{\text{skip}}(\gamma)\bm{x}_\gamma}{c_{\text{out}}(\gamma)}\right]\\
c_{\text{out}}^2(\gamma) &= \mathrm{Var}\left[\bm{\epsilon}-c_{\text{skip}}(\gamma)\bm{x}_\gamma\right]\\
c_{\text{out}}^2(\gamma) &= \mathrm{Var}\left[\bm{\epsilon}-c_{\text{skip}}(\gamma)(\alpha_\gamma\bm{x}_0+\sigma_\gamma\bm{\epsilon})\right]\\
c_{\text{out}}^2(\gamma) &= \mathrm{Var}\left[(1-c_{\text{skip}}(\gamma)\sigma_\gamma)\bm{\epsilon}-c_{\text{skip}}(\gamma)\alpha_\gamma\bm{x}_0\right]\\
c_{\text{out}}^2(\gamma) &= (1-c_{\text{skip}}(\gamma)\sigma_\gamma)^2+c_{\text{skip}}^2(\gamma)\alpha_\gamma^2\sigma_{\text{data}}^2
\end{aligned}
\tag{62}
$$

From principle 3, we have

$$
\begin{aligned}
0 &= \frac{\mathrm{d}c_{\text{out}}^2(\gamma)}{\mathrm{d}c_{\text{skip}}(\gamma)}\\
0 &= -2\sigma_\gamma(1-\sigma_\gamma c_{\text{skip}}(\gamma))+2\alpha_\gamma^2\sigma_{\text{data}}^2c_{\text{skip}}(\gamma)\\
c_{\text{skip}}(\gamma) &= \frac{\sigma_\gamma}{\sigma_\gamma^2+\sigma_{\text{data}}^2\alpha_\gamma^2}
\end{aligned}
\tag{63}
$$

We now substitute Eqn. (63) into Eqn. (62) to obtain the formula for $c_{\text{out}}(\gamma)$:

$$
c_{\text{out}}(\gamma)=\frac{\sigma_{\text{data}}\alpha_\gamma}{\sqrt{\sigma_\gamma^2+\sigma_{\text{data}}^2\alpha_\gamma^2}}
\tag{64}
$$
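The closed forms in Eqns. (61), (63) and (64) can be sanity-checked numerically. The sketch below is our own illustration, not part of the released code: it treats a single dimension, draws $\bm{x}_0\sim\mathcal{N}(0,\sigma_{\text{data}}^2)$ purely as a stand-in for data (the derivation uses only the variance of $\bm{x}_0$ and its independence from $\bm{\epsilon}$), checks principles 1 and 2 by Monte Carlo, and checks principle 3 by a grid search over $c_{\text{skip}}$. The helper names and the values of $\alpha_\gamma$, $\sigma_\gamma$, $\sigma_{\text{data}}$ are arbitrary assumptions.

```python
import numpy as np


def preconditioning(alpha, sigma, sigma_data):
    """c_in, c_skip, c_out from Eqns. (61), (63) and (64)."""
    denom = sigma**2 + sigma_data**2 * alpha**2
    c_in = 1.0 / np.sqrt(denom)
    c_skip = sigma / denom
    c_out = sigma_data * alpha / np.sqrt(denom)
    return c_in, c_skip, c_out


def c_out_sq(c_skip, alpha, sigma, sigma_data):
    """c_out^2 as a function of c_skip (last line of Eqn. (62))."""
    return (1.0 - c_skip * sigma) ** 2 + c_skip**2 * alpha**2 * sigma_data**2


def check_unit_variance(alpha, sigma, sigma_data, n=200_000, seed=0):
    """Monte Carlo estimates of Var[c_in x_gamma] (principle 1) and
    Var[F_target] (principle 2). x0 ~ N(0, sigma_data^2) is only a stand-in
    for data: the derivation uses nothing but the variance of x0."""
    rng = np.random.default_rng(seed)
    x0 = sigma_data * rng.standard_normal(n)
    eps = rng.standard_normal(n)
    x_gamma = alpha * x0 + sigma * eps
    c_in, c_skip, c_out = preconditioning(alpha, sigma, sigma_data)
    f_target = (eps - c_skip * x_gamma) / c_out  # Eqn. (60)
    return np.var(c_in * x_gamma), np.var(f_target)


# Arbitrary illustrative values (not from the paper's experiments).
alpha, sigma, sigma_data = 0.8, 0.6, 0.5
var_in, var_target = check_unit_variance(alpha, sigma, sigma_data)

# Principle 3: a grid search over c_skip should land on the closed form.
grid = np.linspace(0.0, 3.0, 30_001)
best_c_skip = grid[np.argmin(c_out_sq(grid, alpha, sigma, sigma_data))]
```

With these settings both variances should come out close to 1, and the grid minimizer should agree with $c_{\text{skip}}(\gamma)$ from Eqn. (63) up to the grid spacing, since $c_{\text{out}}^2$ is a convex quadratic in $c_{\text{skip}}$.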

If we further assume $\sigma_{\text{data}}=1$ and consider the VP schedule, for which $\alpha_\gamma^2+\sigma_\gamma^2=1$, the coefficients reduce to

$$
c_{\text{in}}(\gamma)=1,\quad c_{\text{skip}}(\gamma)=\sigma_\gamma,\quad c_{\text{out}}(\gamma)=\alpha_\gamma
\tag{65}
$$

In this case, the preconditioning agrees with our velocity parameterization, with $\tilde{\bm{v}}_\theta(\bm{x}_\gamma,\gamma)=F_\theta(\bm{x}_\gamma,\gamma)$. In practice, we find that setting $\sigma_{\text{data}}=0.5$ as in Karras et al. (2022) makes the loss decrease faster at the start of training but slows convergence as training proceeds.
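To make this agreement explicit, one can substitute Eqn. (65) into Eqn. (60). The short expansion below is our own worked step; it uses only $\bm{x}_\gamma=\alpha_\gamma\bm{x}_0+\sigma_\gamma\bm{\epsilon}$ and the VP constraint $1-\sigma_\gamma^2=\alpha_\gamma^2$. It shows that the effective target becomes $\alpha_\gamma\bm{\epsilon}-\sigma_\gamma\bm{x}_0$, the same quantity that v prediction regresses onto (Section E.3), and that the implied noise predictor is a fixed linear combination of $\bm{x}_\gamma$ and $F_\theta$:

```latex
\begin{aligned}
F_{\text{target}}(\bm{x}_0,\bm{\epsilon},\gamma)
&= \frac{\bm{\epsilon}-\sigma_\gamma(\alpha_\gamma\bm{x}_0+\sigma_\gamma\bm{\epsilon})}{\alpha_\gamma}
 = \frac{(1-\sigma_\gamma^2)\bm{\epsilon}-\alpha_\gamma\sigma_\gamma\bm{x}_0}{\alpha_\gamma}
 = \alpha_\gamma\bm{\epsilon}-\sigma_\gamma\bm{x}_0,\\
\bm{\epsilon}_\theta(\bm{x}_\gamma,\gamma)
&= \sigma_\gamma\bm{x}_\gamma+\alpha_\gamma F_\theta(\bm{x}_\gamma,\gamma).
\end{aligned}
```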

E.2 Connection to flow matching in Lipman et al. (2022)

Lipman et al. (2022) define a conditional probability path $p_t(\bm{x}|\bm{x}_0)$ that gradually moves the data $\bm{x}_0\sim q(\bm{x}_0)$ to a target distribution $p_1(\bm{x})$. Note that they use $t=1$ to represent the data distribution and $t=0$ to represent the target distribution; for consistency with our notation, we reverse their time convention. They obtain the marginal probability path by marginalizing over $q(\bm{x}_0)$:

$$
p_t(\bm{x})=\int p_t(\bm{x}|\bm{x}_0)q(\bm{x}_0)\,\mathrm{d}\bm{x}_0
\tag{66}
$$

They aim to learn a vector field $\bm{v}_t(\bm{x})$ that defines a flow $\bm{\phi}:[0,1]\times\mathbb{R}^d\rightarrow\mathbb{R}^d$ via

$$
\frac{\mathrm{d}}{\mathrm{d}t}\bm{\phi}_t(\bm{x})=\bm{v}_t(\bm{\phi}_t(\bm{x})),\quad\bm{\phi}_1(\bm{x})=\bm{x}
\tag{67}
$$

so that the marginal $p_t$ is generated by the push-forward $p_t=[\bm{\phi}_t]_*p_1$. In practice, they consider Gaussian conditional probability paths

$$
p_t(\bm{x}|\bm{x}_0)=\mathcal{N}(\bm{x}|\bm{\mu}_t(\bm{x}_0),\sigma_t^2(\bm{x}_0)\bm{I})
\tag{68}
$$

and propose a conditional flow matching (CFM) objective for simulation-free training of $\bm{v}_t(\bm{x})$:

$$
\mathcal{L}_{\text{CFM}}(\theta)=\mathbb{E}_{t,q(\bm{x}_0),p_t(\bm{x}|\bm{x}_0)}\|\bm{v}_t(\bm{x})-\bm{u}_t(\bm{x}|\bm{x}_0)\|_2^2
\tag{69}
$$

where

$$
\bm{u}_t(\bm{x}|\bm{x}_0)=\frac{\sigma'_t(\bm{x}_0)}{\sigma_t(\bm{x}_0)}(\bm{x}-\bm{\mu}_t(\bm{x}_0))+\bm{\mu}'_t(\bm{x}_0)
\tag{70}
$$

Suppose the mean $\bm{\mu}_t(\bm{x}_0)$ is linear in $\bm{x}_0$ and the standard deviation $\sigma_t(\bm{x}_0)$ does not depend on $\bm{x}_0$, as in both cases studied experimentally in flow matching. Setting $\bm{\mu}_t(\bm{x}_0)=\alpha_t\bm{x}_0$ and $\sigma_t(\bm{x}_0)=\sigma_t$, we have

$$
\bm{u}_t(\bm{x}|\bm{x}_0)=\frac{\dot{\sigma}_t}{\sigma_t}(\bm{x}-\alpha_t\bm{x}_0)+\dot{\alpha}_t\bm{x}_0=\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon}
\tag{71}
$$

where we use $\bm{x}=\alpha_t\bm{x}_0+\sigma_t\bm{\epsilon}$, $\bm{\epsilon}\sim\mathcal{N}(\bm{0},\bm{I})$, since $p_t(\bm{x}|\bm{x}_0)=\mathcal{N}(\bm{x}|\alpha_t\bm{x}_0,\sigma_t^2\bm{I})$.
These quantities correspond to our notation: the conditional probability path $p_t(\bm{x}|\bm{x}_0)$ corresponds to the Gaussian transition kernel $q_{0t}(\bm{x}_t|\bm{x}_0)$ of the forward diffusion process; the marginal probability path $p_t(\bm{x})$ corresponds to the ground-truth marginals $q_t(\bm{x}_t)$ of the forward diffusion process; and the matching target $\bm{u}_t(\bm{x}|\bm{x}_0)$ in CFM corresponds to the velocity of the diffusion path, $\bm{v}=\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon}$, in our formulation.
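As a quick numerical check of the step from Eqn. (70) to Eqn. (71), the sketch below evaluates the CFM target on the reversed-time OT path $\alpha_t=1-t$, $\sigma_t=1-(1-\sigma_{\min})(1-t)$ and compares it with $\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon}$. The value of `SIGMA_MIN`, the time $t=0.3$, and the helper names `ot_path` and `cfm_target` are illustrative assumptions, not taken from Lipman et al. (2022) or from our code.

```python
import numpy as np

SIGMA_MIN = 1e-4  # illustrative value of sigma_min, chosen for this sketch


def ot_path(t, sigma_min=SIGMA_MIN):
    """Reversed-time OT path: (alpha_t, sigma_t, d alpha_t/dt, d sigma_t/dt)."""
    alpha = 1.0 - t
    sigma = 1.0 - (1.0 - sigma_min) * (1.0 - t)
    return alpha, sigma, -1.0, 1.0 - sigma_min


def cfm_target(x, x0, t):
    """Eqn. (70) with mu_t(x0) = alpha_t * x0 and sigma_t(x0) = sigma_t."""
    alpha, sigma, d_alpha, d_sigma = ot_path(t)
    return d_sigma / sigma * (x - alpha * x0) + d_alpha * x0


rng = np.random.default_rng(0)
x0, eps = rng.standard_normal((2, 8))
t = 0.3
alpha, sigma, d_alpha, d_sigma = ot_path(t)
x = alpha * x0 + sigma * eps        # a sample from p_t(x | x0)
u = cfm_target(x, x0, t)            # left-hand side of Eqn. (71)
v = d_alpha * x0 + d_sigma * eps    # right-hand side of Eqn. (71)
```

The two vectors agree to floating-point precision for any $\bm{x}_0$ and $\bm{\epsilon}$, because $(\bm{x}-\alpha_t\bm{x}_0)/\sigma_t$ is exactly $\bm{\epsilon}$ for samples drawn from the conditional path.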

Therefore, when specialized to Gaussian diffusion processes, the CFM objective of Lipman et al. (2022) is exactly a velocity parameterization, similar to our first-order objective in the pretraining phase. We can express CFM in a simpler form that is easier to analyze and generalizes to any noise schedule. By the equivalence of different predictors (Theorem B.1) and the relationship between $f(t),g(t)$ and $\alpha_t,\sigma_t$, we have

$$
\begin{aligned}
\mathcal{L}_{\text{CFM}}(\theta)
&=\int_0^T\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\|\bm{v}_\theta(\bm{x}_t,t)-\bm{v}\|_2^2\,\mathrm{d}t\\
&=\int_0^T\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left\|f(t)\bm{x}_t-\frac{1}{2}g^2(t)\bm{s}_\theta(\bm{x}_t,t)-(\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon})\right\|_2^2\mathrm{d}t\\
&=\int_0^T\frac{1}{4}g^4(t)\,\mathbb{E}_{\bm{x}_0,\bm{\epsilon}}\left\|\bm{s}_\theta(\bm{x}_t,t)+\frac{\bm{\epsilon}}{\sigma_t}\right\|_2^2\mathrm{d}t
\end{aligned}
\tag{72}
$$

where the last equality uses $f(t)\bm{x}_t+\frac{1}{2}g^2(t)\frac{\bm{\epsilon}}{\sigma_t}=\dot{\alpha}_t\bm{x}_0+\dot{\sigma}_t\bm{\epsilon}$, which follows from $f(t)=\dot{\alpha}_t/\alpha_t$ and $g^2(t)=2\sigma_t\dot{\sigma}_t-2f(t)\sigma_t^2$. This shows that the CFM objective not only changes the parameterization but also imposes a different time weighting $w(t)=\frac{1}{4}g^4(t)$ on the original denoising score matching objective. When training aims to improve sample quality (e.g., FID), the optimal choice of $w(t)$ is still an open problem.
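The per-sample identity behind Eqn. (72) can be checked numerically. The sketch below is a hypothetical check, not part of our training code: it treats an arbitrary random vector `s` as a stand-in for the score-network output $\bm{s}_\theta(\bm{x}_t,t)$, converts it to a velocity via $f(t)\bm{x}_t-\frac{1}{2}g^2(t)\bm{s}$, and compares the CFM loss with the score-matching loss weighted by $\frac{1}{4}g^4(t)$. It uses the reversed-time OT path at an illustrative time $t=0.3$ with an assumed $\sigma_{\min}=10^{-4}$, together with the standard relations $f(t)=\dot{\alpha}_t/\alpha_t$ and $g^2(t)=2\sigma_t\dot{\sigma}_t-2f(t)\sigma_t^2$.

```python
import numpy as np


def forward_sde_coeffs(alpha, sigma, d_alpha, d_sigma):
    """f(t) = d log(alpha_t)/dt and g^2(t) = d sigma_t^2/dt - 2 f(t) sigma_t^2."""
    f = d_alpha / alpha
    g2 = 2.0 * sigma * d_sigma - 2.0 * f * sigma**2
    return f, g2


def cfm_vs_weighted_dsm(x0, eps, s, alpha, sigma, d_alpha, d_sigma):
    """Per-sample CFM loss of the velocity induced by a score s, and the
    score-matching loss weighted by g^4(t)/4 (integrands of Eqn. (72))."""
    f, g2 = forward_sde_coeffs(alpha, sigma, d_alpha, d_sigma)
    x_t = alpha * x0 + sigma * eps
    v_theta = f * x_t - 0.5 * g2 * s          # probability-flow velocity from the score
    v = d_alpha * x0 + d_sigma * eps          # CFM target, Eqn. (71)
    cfm = np.sum((v_theta - v) ** 2)
    dsm = 0.25 * g2**2 * np.sum((s + eps / sigma) ** 2)
    return cfm, dsm


# Illustrative setting: reversed-time OT path at t = 0.3 with sigma_min = 1e-4.
t, sigma_min = 0.3, 1e-4
alpha, sigma = 1.0 - t, 1.0 - (1.0 - sigma_min) * (1.0 - t)
d_alpha, d_sigma = -1.0, 1.0 - sigma_min

rng = np.random.default_rng(0)
x0, eps, s = rng.standard_normal((3, 16))  # s: arbitrary stand-in for s_theta(x_t, t)
cfm, dsm = cfm_vs_weighted_dsm(x0, eps, s, alpha, sigma, d_alpha, d_sigma)
```

Because the identity holds sample by sample, `cfm` and `dsm` match for any choice of `s`, which is why CFM amounts to denoising score matching under the weighting $w(t)=\frac{1}{4}g^4(t)$.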

Comparing the CFM objective with our first-order objective in Eqn. (50), the practical differences are that we use the normalized predictor $\tilde{\bm{v}}_\theta$, adopt $\gamma$ timing, and apply likelihood weighting. Likelihood weighting refers to the time weighting $w(t)=\frac{g^2(t)}{2\sigma_t^2}$ in Eqn. (5) and $w(t)=\frac{2}{g^2(t)}$ in Eqn. (12); it is consistent across parameterizations and is the theoretically optimal choice for maximum likelihood training (Song et al., 2021c). Moreover, changing the time domain from $t$ to $\gamma$ does not alter the value of the objective, but it does affect the variance of the Monte Carlo estimate and the convergence speed, as discussed earlier. For example, the OT path in Lipman et al. (2022) is $\alpha_t=1-t$, $\sigma_t=1-(1-\sigma_{\min})(1-t)\approx t$, so the relation between $\gamma$ and $t$ is $\gamma=\log(\sigma_t^2/\alpha_t^2)\approx 2\log(t/(1-t))$, with equality when $\sigma_{\min}=0$. Under $\gamma$ timing, the objective is decoupled from the choice of noise schedule as far as possible, and the change of variable from $\gamma$ to $t$ can be regarded as a tunable importance sampling procedure.
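To illustrate the last point, the sketch below maps $t$ to $\gamma$ on the OT path and estimates the same integral over $\gamma$ in two ways: by sampling $\gamma$ uniformly, or by sampling $t$ uniformly and reweighting with the Jacobian $\mathrm{d}\gamma/\mathrm{d}t$. Both estimators are unbiased for the same value; choosing the schedule amounts to choosing the proposal, which changes only the Monte Carlo variance. This is our own illustration: the integrand `h`, the range $[-4,4]$ and the sample size are arbitrary assumptions, and the inverse map and Jacobian are written for $\sigma_{\min}=0$ only.

```python
import math

import numpy as np


def gamma_of_t(t, sigma_min=0.0):
    """gamma = log(sigma_t^2 / alpha_t^2) (negative log-SNR) on the reversed-time OT path."""
    alpha = 1.0 - t
    sigma = 1.0 - (1.0 - sigma_min) * (1.0 - t)
    return np.log(sigma**2 / alpha**2)


def t_of_gamma(gamma):
    """Inverse for sigma_min = 0: gamma = 2 log(t / (1 - t)) gives t = sigmoid(gamma / 2)."""
    return 1.0 / (1.0 + np.exp(-gamma / 2.0))


def dgamma_dt(t):
    """Jacobian d gamma / d t for sigma_min = 0."""
    return 2.0 / (t * (1.0 - t))


def integral_over_gamma(h, g0, g1, n=1_000_000, seed=0):
    """Estimate the integral of h over [g0, g1] with (a) gamma ~ U(g0, g1) and
    (b) t ~ U(t(g0), t(g1)) reweighted by d gamma / d t (importance sampling)."""
    rng = np.random.default_rng(seed)
    gam = rng.uniform(g0, g1, n)
    est_gamma = (g1 - g0) * np.mean(h(gam))
    t0, t1 = t_of_gamma(g0), t_of_gamma(g1)
    t = rng.uniform(t0, t1, n)
    est_t = (t1 - t0) * np.mean(h(gamma_of_t(t)) * dgamma_dt(t))
    return est_gamma, est_t


def h(gamma):
    """Arbitrary smooth test integrand standing in for a per-gamma loss."""
    return np.exp(-0.5 * gamma**2)


est_gamma, est_t = integral_over_gamma(h, -4.0, 4.0)
truth = math.sqrt(2.0 * math.pi) * math.erf(4.0 / math.sqrt(2.0))
```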

In addition, normalizing the field is necessary for stable training of the velocity predictor and is the key to unifying v prediction and preconditioning. Similar strategies have been adopted in more general physics-inspired generative models; for example, Xu et al. (2022, 2023) normalize the Poisson field when training Poisson flow generative models.

E.3 Connection to v prediction

In Salimans & Ho (2022) and Ho et al. (2022), a technique called “v prediction” is used, which parameterizes a network to predict $\mathbf{v}=\alpha_{t}\bm{\epsilon}-\sigma_{t}\bm{x}_{0}$. Assuming a VP schedule as in their work, we have $\alpha_{t}^{2}+\sigma_{t}^{2}=1$; taking the derivative w.r.t. $t$ gives $\alpha_{t}\dot{\alpha}_{t}+\sigma_{t}\dot{\sigma}_{t}=0$, and then

$$\dot{\alpha}_{t}=-\frac{\sigma_{t}\dot{\sigma}_{t}}{\alpha_{t}},\quad\frac{\mathrm{d}\log\alpha_{t}}{\mathrm{d}t}=\frac{\dot{\alpha}_{t}}{\alpha_{t}}=-\frac{\sigma_{t}\dot{\sigma}_{t}}{\alpha_{t}^{2}}\tag{73}$$

so

$$\begin{aligned}
g^{2}(t)&=\frac{\mathrm{d}\sigma_{t}^{2}}{\mathrm{d}t}-2\frac{\mathrm{d}\log\alpha_{t}}{\mathrm{d}t}\sigma_{t}^{2}\\
&=2\sigma_{t}\dot{\sigma}_{t}+2\frac{\sigma_{t}\dot{\sigma}_{t}}{\alpha_{t}^{2}}\sigma_{t}^{2}\\
&=\frac{2\sigma_{t}\dot{\sigma}_{t}}{\alpha_{t}^{2}}(\alpha_{t}^{2}+\sigma_{t}^{2})\\
&=\frac{2\sigma_{t}\dot{\sigma}_{t}}{\alpha_{t}^{2}}
\end{aligned}\tag{74}$$

and the velocity is

$$\begin{aligned}
\bm{v}&=\dot{\alpha}_{t}\bm{x}_{0}+\dot{\sigma}_{t}\bm{\epsilon}\\
&=\dot{\sigma}_{t}\bm{\epsilon}-\frac{\sigma_{t}\dot{\sigma}_{t}}{\alpha_{t}}\bm{x}_{0}\\
&=\frac{\dot{\sigma}_{t}}{\alpha_{t}}(\alpha_{t}\bm{\epsilon}-\sigma_{t}\bm{x}_{0})\\
&=\frac{\alpha_{t}}{2\sigma_{t}}g^{2}(t)(\alpha_{t}\bm{\epsilon}-\sigma_{t}\bm{x}_{0})
\end{aligned}\tag{75}$$

We can also compute the normalizing factor as

$$\sqrt{\dot{\alpha}_{t}^{2}+\dot{\sigma}_{t}^{2}}=\sqrt{\frac{\sigma_{t}^{2}\dot{\sigma}_{t}^{2}}{\alpha_{t}^{2}}+\dot{\sigma}_{t}^{2}}=\frac{\dot{\sigma}_{t}}{\alpha_{t}}=\frac{\alpha_{t}}{2\sigma_{t}}g^{2}(t)\tag{76}$$

so we have the normalized velocity

$$\tilde{\bm{v}}=\frac{\bm{v}}{\sqrt{\dot{\alpha}_{t}^{2}+\dot{\sigma}_{t}^{2}}}=\alpha_{t}\bm{\epsilon}-\sigma_{t}\bm{x}_{0}\tag{77}$$

Therefore, $\mathbf{v}=\tilde{\bm{v}}$, which means that v prediction is a special case of velocity parameterization when the noise schedule is VP.
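For a concrete check of Eqns. (73) to (77), the sketch below assumes a cosine VP schedule, $\alpha_t=\cos(\pi t/2)$ and $\sigma_t=\sin(\pi t/2)$, chosen only for illustration. It compares the general definition of $g^2(t)$, evaluated with central finite differences, against the VP simplification $2\sigma_t\dot\sigma_t/\alpha_t^2$, and checks that the normalized velocity coincides with the v-prediction target $\alpha_t\bm\epsilon-\sigma_t\bm x_0$. The function names are hypothetical.

```python
import math


def vp_cosine(t):
    """Assumed VP schedule: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2), plus derivatives."""
    a = math.cos(0.5 * math.pi * t)
    s = math.sin(0.5 * math.pi * t)
    a_dot = -0.5 * math.pi * s
    s_dot = 0.5 * math.pi * a
    return a, s, a_dot, s_dot


def g2_general(t, h=1e-6):
    """g^2(t) = d sigma_t^2/dt - 2 (d log alpha_t/dt) sigma_t^2, by central finite differences."""
    s = vp_cosine(t)[1]
    ap, sp = vp_cosine(t + h)[:2]
    am, sm = vp_cosine(t - h)[:2]
    dsig2 = (sp**2 - sm**2) / (2.0 * h)
    dlog_alpha = (math.log(ap) - math.log(am)) / (2.0 * h)
    return dsig2 - 2.0 * dlog_alpha * s**2


def g2_vp(t):
    """VP simplification from Eqn. (74): g^2(t) = 2 sigma_t sigma_dot_t / alpha_t^2."""
    a, s, _, s_dot = vp_cosine(t)
    return 2.0 * s * s_dot / a**2


def normalized_velocity(x0, eps, t):
    """v / sqrt(alpha_dot^2 + sigma_dot^2) with v = alpha_dot x_0 + sigma_dot eps."""
    a, s, a_dot, s_dot = vp_cosine(t)
    norm = math.sqrt(a_dot**2 + s_dot**2)
    return [(a_dot * xi + s_dot * ei) / norm for xi, ei in zip(x0, eps)]


def v_prediction_target(x0, eps, t):
    """The v-prediction target alpha_t eps - sigma_t x_0."""
    a, s = vp_cosine(t)[:2]
    return [a * ei - s * xi for xi, ei in zip(x0, eps)]
```

For this particular schedule the normalizing factor $\dot\sigma_t/\alpha_t$ is the constant $\pi/2$, so the normalization only rescales $\bm v$; for other VP schedules it varies with $t$.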

Appendix F Error-bounded trace of second-order flow matching

Here we provide the proofs for the error-bounded trace of second-order flow matching. First, we provide a lemma that gives the Jacobian of the ground-truth velocity predictor $\bm{v}^{*}(\bm{x}_{t},t)$.

Lemma F.1.

Suppose $(\bm{x}_{0},\bm{x}_{t})\sim q(\bm{x}_{0},\bm{x}_{t})$, and denote $\bm{x}_{t}=\alpha_{t}\bm{x}_{0}+\sigma_{t}\bm{\epsilon}$, $\bm{v}=\dot{\alpha}_{t}\bm{x}_{0}+\dot{\sigma}_{t}\bm{\epsilon}$ and $\nabla(\cdot)=\nabla_{\bm{x}_{t}}(\cdot)$. Then we have

$$\nabla\bm{v}^{*}(\bm{x}_{t},t)=\frac{\dot{\sigma}_{t}}{\sigma_{t}}\bm{I}-\frac{2}{g^{2}(t)}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})^{\top}\right]\tag{78}$$

and

$$\mathrm{tr}(\nabla\bm{v}^{*}(\bm{x}_{t},t))=\frac{\dot{\sigma}_{t}}{\sigma_{t}}d-\frac{2}{g^{2}(t)}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[\|\bm{v}^{*}(\bm{x}_{t},t)-\bm{v}\|_{2}^{2}\right]\tag{79}$$
Proof.

First, the gradient of $q_{t0}$ can be calculated as

$$\begin{aligned}
\nabla q_{t0}(\bm{x}_{0}|\bm{x}_{t})&=\nabla\frac{q_{0}(\bm{x}_{0})q_{0t}(\bm{x}_{t}|\bm{x}_{0})}{q_{t}(\bm{x}_{t})}\\
&=q_{0}(\bm{x}_{0})\frac{q_{t}(\bm{x}_{t})\nabla q_{0t}(\bm{x}_{t}|\bm{x}_{0})-q_{0t}(\bm{x}_{t}|\bm{x}_{0})\nabla q_{t}(\bm{x}_{t})}{q_{t}(\bm{x}_{t})^{2}}\\
&=\frac{q_{0}(\bm{x}_{0})q_{0t}(\bm{x}_{t}|\bm{x}_{0})}{q_{t}(\bm{x}_{t})}\left(\nabla\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})-\nabla\log q_{t}(\bm{x}_{t})\right)\\
&=q_{t0}(\bm{x}_{0}|\bm{x}_{t})\left(\nabla\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})-\nabla\log q_{t}(\bm{x}_{t})\right)\\
&=\frac{2}{g^{2}(t)}(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})\,q_{t0}(\bm{x}_{0}|\bm{x}_{t})
\end{aligned}\tag{80}$$

where we use the relation between $\bm{v}^{*}(\bm{x}_{t},t)$ and $\nabla\log q_{t}(\bm{x}_{t})$ in Theorem B.1, together with the analogous relation between $\bm{v}$ and $\nabla\log q_{0t}(\bm{x}_{t}|\bm{x}_{0})$. From Eqn. (48), we know $\bm{v}^{*}(\bm{x}_{t},t)=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{v}]$, and for given $\bm{x}_{0}$, we have

$$\nabla\bm{v}=\nabla\left(\dot{\alpha}_{t}\bm{x}_{0}+\dot{\sigma}_{t}\frac{\bm{x}_{t}-\alpha_{t}\bm{x}_{0}}{\sigma_{t}}\right)=\frac{\dot{\sigma}_{t}}{\sigma_{t}}\bm{I}\tag{81}$$

and

$$\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})\bm{v}^{*}(\bm{x}_{t},t)^{\top}\right]=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[\bm{v}^{*}(\bm{x}_{t},t)-\bm{v}\right]\bm{v}^{*}(\bm{x}_{t},t)^{\top}=\bm{0}\tag{82}$$

So

$$\begin{aligned}
\nabla\bm{v}^{*}(\bm{x}_{t},t)&=\nabla\int q_{t0}(\bm{x}_{0}|\bm{x}_{t})\bm{v}\,\mathrm{d}\bm{x}_{0}\\
&=\int\left(\nabla q_{t0}(\bm{x}_{0}|\bm{x}_{t})\bm{v}^{\top}+q_{t0}(\bm{x}_{0}|\bm{x}_{t})\nabla\bm{v}\right)\mathrm{d}\bm{x}_{0}\\
&=\int q_{t0}(\bm{x}_{0}|\bm{x}_{t})\left(\frac{2}{g^{2}(t)}(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})\bm{v}^{\top}+\frac{\dot{\sigma}_{t}}{\sigma_{t}}\bm{I}\right)\mathrm{d}\bm{x}_{0}\\
&=\frac{\dot{\sigma}_{t}}{\sigma_{t}}\bm{I}+\frac{2}{g^{2}(t)}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})\bm{v}^{\top}\right]\\
&=\frac{\dot{\sigma}_{t}}{\sigma_{t}}\bm{I}-\frac{2}{g^{2}(t)}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})(\bm{v}^{*}(\bm{x}_{t},t)-\bm{v})^{\top}\right]
\end{aligned}\tag{83}$$

Taking the trace of both sides, with $\mathrm{tr}(\bm{I})=d$ and $\mathrm{tr}(\bm{a}\bm{a}^{\top})=\|\bm{a}\|_{2}^{2}$, gives the expression for $\mathrm{tr}(\nabla\bm{v}^{*}(\bm{x}_{t},t))$ in Eqn. (79). ∎
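Lemma F.1 can be sanity-checked on a toy problem where $\bm v^*$ is available in closed form. The sketch below assumes, for illustration only, a discrete data distribution on three points in $\mathbb{R}^2$ and the linear schedule $\alpha_t=1-t$, $\sigma_t=t$. It compares a central finite-difference estimate of $\mathrm{tr}(\nabla\bm v^*)$ with the right-hand side of Eqn. (79). The point set, weights, and function names are hypothetical.

```python
import math

# Toy setup (an assumption for illustration): q_0 is a discrete distribution on
# three points in R^2, and the schedule is alpha_t = 1 - t, sigma_t = t.
POINTS = [(1.0, 0.0), (-1.0, 0.5), (0.3, -1.0)]
WEIGHTS = [0.2, 0.5, 0.3]


def schedule(t):
    """Return alpha_t, sigma_t, their time derivatives, and g^2(t)."""
    alpha, sigma = 1.0 - t, t
    alpha_dot, sigma_dot = -1.0, 1.0
    # g^2(t) = d sigma_t^2 / dt - 2 (d log alpha_t / dt) sigma_t^2
    g2 = 2.0 * sigma * sigma_dot - 2.0 * (alpha_dot / alpha) * sigma**2
    return alpha, sigma, alpha_dot, sigma_dot, g2


def posterior(x, t):
    """q_{t0}(x_0 = POINTS[k] | x_t = x) for every k, via a numerically stable softmax."""
    alpha, sigma = schedule(t)[:2]
    logits = [
        math.log(w) - sum((xi - alpha * pi) ** 2 for xi, pi in zip(x, p)) / (2.0 * sigma**2)
        for w, p in zip(WEIGHTS, POINTS)
    ]
    m = max(logits)
    unnorm = [math.exp(lg - m) for lg in logits]
    z = sum(unnorm)
    return [u / z for u in unnorm]


def conditional_velocities(x, t):
    """v = alpha_dot x_0 + sigma_dot eps with eps = (x_t - alpha x_0) / sigma, one per point."""
    alpha, sigma, alpha_dot, sigma_dot, _ = schedule(t)
    return [
        [alpha_dot * pi + sigma_dot * (xi - alpha * pi) / sigma for xi, pi in zip(x, p)]
        for p in POINTS
    ]


def v_star(x, t):
    """Ground-truth velocity v*(x_t, t) = E_{q_{t0}}[v]."""
    probs = posterior(x, t)
    vs = conditional_velocities(x, t)
    return [sum(pk * vk[i] for pk, vk in zip(probs, vs)) for i in range(len(x))]


def trace_jacobian_fd(x, t, h=1e-5):
    """tr(grad v*) estimated by central finite differences in x_t."""
    tr = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        tr += (v_star(xp, t)[i] - v_star(xm, t)[i]) / (2.0 * h)
    return tr


def trace_jacobian_lemma(x, t):
    """Right-hand side of Eqn. (79): sigma_dot/sigma * d - 2/g^2 * E||v* - v||^2."""
    _, sigma, _, sigma_dot, g2 = schedule(t)
    probs = posterior(x, t)
    vs = conditional_velocities(x, t)
    vbar = v_star(x, t)
    mse = sum(pk * sum((a - b) ** 2 for a, b in zip(vbar, vk)) for pk, vk in zip(probs, vs))
    return sigma_dot / sigma * len(x) - 2.0 / g2 * mse
```

Evaluating both functions at, say, `x = [0.2, 0.1]` and `t = 0.5` should give values that agree up to finite-difference error; changing `WEIGHTS` or `POINTS` should not break the agreement, since the lemma makes no assumption on the data distribution.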

Then we prove Theorem 4.1 as follows.

Proof.

The optimization in Eqn. (15) can be rewritten as

$$\theta^{*}=\operatorname*{argmin}_{\theta}\frac{2\sigma_{t}^{2}}{g^{2}(t)}\mathbb{E}_{q_{t}(\bm{x}_{t})}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)-\frac{\dot{\sigma}_{t}}{\sigma_{t}}d+\frac{2}{g^{2}(t)}\|\hat{\bm{v}}_{1}(\bm{x}_{t},t)-\bm{v}\|_{2}^{2}\right|^{2}\right].\tag{84}$$

For fixed $t$ and $\bm{x}_{t}$, minimizing the inner expectation is a minimum mean square error problem for $\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)$, so the optimal $\theta^{*}$ satisfies

$$\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta^{*})=\frac{\dot{\sigma}_{t}}{\sigma_{t}}d-\frac{2}{g^{2}(t)}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[\|\hat{\bm{v}}_{1}(\bm{x}_{t},t)-\bm{v}\|_{2}^{2}\right] \tag{85}$$

Using Lemma 79 and $\bm{v}^{*}(\bm{x}_{t},t)=\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}[\bm{v}]$, we have

$$\begin{aligned}
&\mathrm{tr}(\nabla\bm{v}^{*}(\bm{x}_{t},t))-\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta^{*})\\
={}&\frac{2}{g^{2}(t)}\mathbb{E}_{q_{t0}(\bm{x}_{0}|\bm{x}_{t})}\left[\|\hat{\bm{v}}_{1}(\bm{x}_{t},t)\|_{2}^{2}-2\bm{v}^{\top}\hat{\bm{v}}_{1}(\bm{x}_{t},t)-\|\bm{v}^{*}(\bm{x}_{t},t)\|_{2}^{2}+2\bm{v}^{\top}\bm{v}^{*}(\bm{x}_{t},t)\right]\\
={}&\frac{2}{g^{2}(t)}\|\hat{\bm{v}}_{1}(\bm{x}_{t},t)-\bm{v}^{*}(\bm{x}_{t},t)\|_{2}^{2}
\end{aligned} \tag{86}$$
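The last equality in Eq. (86) is a bias-variance decomposition that follows from $\bm{v}^{*}=\mathbb{E}_{q_{t0}(\bm{x}_0|\bm{x}_t)}[\bm{v}]$. The minimal sketch below checks this identity numerically, together with the MMSE property behind Eq. (85). It assumes a toy Gaussian conditional distribution of $\bm{v}$ at one fixed $\bm{x}_t$; the distribution, dimension, and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 100_000

# Toy samples of the regression target v drawn from q_{t0}(x_0 | x_t) at one fixed x_t.
v = rng.normal(loc=1.5, scale=0.7, size=(n, d))
v_star = v.mean(axis=0)                          # conditional mean (empirical)
v_hat1 = v_star + rng.normal(scale=0.3, size=d)  # an imperfect first-order estimate

# E||v_hat1 - v||^2 - E||v - v*||^2 should equal ||v_hat1 - v*||^2 (= delta_1^2).
lhs = (np.mean(np.sum((v_hat1 - v) ** 2, axis=1))
       - np.mean(np.sum((v - v_star) ** 2, axis=1)))
rhs = np.sum((v_hat1 - v_star) ** 2)

# MMSE property behind Eq. (85): for a scalar target y (here the squared-norm
# term of Eq. (84)), the constant c minimizing mean((c - y)^2) is mean(y).
y = np.sum((v_hat1 - v) ** 2, axis=1)

def mse(c):
    return np.mean((c - y) ** 2)

c_opt = y.mean()
print(lhs, rhs)
print(mse(c_opt) < min(mse(c_opt - 0.01), mse(c_opt + 0.01)))
```

With the empirical mean in place of $\bm{v}^{*}$ the identity holds exactly (up to floating-point error), which is why the check uses `v.mean(axis=0)`.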

Therefore, by the triangle inequality, we obtain the error bound

$$\begin{aligned}
\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)-\mathrm{tr}(\nabla_{\bm{x}}\bm{v}^{*}(\bm{x}_{t},t))\right|
&\leq\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)-\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta^{*})\right|+\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta^{*})-\mathrm{tr}(\nabla\bm{v}^{*}(\bm{x}_{t},t))\right|\\
&=\left|\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta)-\bm{v}_{2}^{\text{trace}}(\bm{x}_{t},t;\theta^{*})\right|+\frac{2}{g^{2}(t)}\delta_{1}^{2}(\bm{x}_{t},t)
\end{aligned} \tag{87}$$

where $\delta_{1}(\bm{x}_{t},t)=\|\hat{\bm{v}}_{1}(\bm{x}_{t},t)-\bm{v}^{*}(\bm{x}_{t},t)\|_{2}$ is the first-order estimation error. ∎

Appendix G Difference between our second-order flow matching and the previous time score matching in Choi et al. (2022)

We propose the error-bounded second-order flow matching objective to regularize $-\mathrm{tr}(\nabla_{\bm{x}}\bm{v}_{\theta}(\bm{x}_{t},t))$, which equals $\frac{\mathrm{d}\log p_{t}(\bm{x}_{t})}{\mathrm{d}t}$ by the "Instantaneous Change of Variables" formula for CNFs (Chen et al., 2018a). Choi et al. (2022) propose a joint score matching method that estimates both the data score and the time score $(\nabla_{\bm{x}}\log p_{t}(\bm{x}),\partial_{t}\log p_{t}(\bm{x}))$. Although the two approaches may seem related, they are fundamentally different.

First, the change-of-variables formula for CNFs describes the total time derivative of $\log p_{t}(\bm{x}_{t})$, where $\bm{x}_{t}$ evolves along the ODE flow trajectory, rather than the derivative at a fixed data point $\bm{x}\in\mathbb{R}^{d}$. In contrast, $\partial_{t}\log p_{t}(\bm{x})$ in Choi et al. (2022) is the partial derivative of $\log p_{t}(\bm{x})$ at any fixed point $\bm{x}\in\mathbb{R}^{d}$ in the whole space. Specifically, according to the Fokker-Planck equation (which, for an ODE without a diffusion term, reduces to the continuity equation), we have

$$\partial_{t}p_{t}(\bm{x})=-\nabla_{\bm{x}}\cdot\left(p_{t}(\bm{x})\bm{v}_{\theta}(\bm{x},t)\right) \tag{88}$$

It follows that

$$\partial_{t}\log p_{t}(\bm{x})=-\nabla_{\bm{x}}\cdot\bm{v}_{\theta}(\bm{x},t)-\bm{v}_{\theta}(\bm{x},t)^{\top}\nabla_{\bm{x}}\log p_{t}(\bm{x}) \tag{89}$$

Since $\bm{x}_{t}$ moves with velocity $\bm{v}_{\theta}(\bm{x}_{t},t)$, the chain rule gives $\frac{\mathrm{d}\log p_{t}(\bm{x}_{t})}{\mathrm{d}t}=\partial_{t}\log p_{t}(\bm{x}_{t})+\bm{v}_{\theta}(\bm{x}_{t},t)^{\top}\nabla_{\bm{x}}\log p_{t}(\bm{x}_{t})=-\nabla_{\bm{x}}\cdot\bm{v}_{\theta}(\bm{x}_{t},t)$, so the two quantities differ by the term $\bm{v}_{\theta}^{\top}\nabla_{\bm{x}}\log p_{t}$. Therefore, the total derivative $\frac{\mathrm{d}\log p_{t}(\bm{x}_{t})}{\mathrm{d}t}$ we care about is different from the partial derivative $\partial_{t}\log p_{t}(\bm{x})$ in Choi et al. (2022), and the corresponding training objectives also differ (with different optimal solutions).
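To make the distinction concrete, consider a hypothetical one-dimensional linear flow $x_t=(1+t)x_0$ with $x_0\sim\mathcal{N}(0,1)$, so that $p_t=\mathcal{N}(0,(1+t)^2)$ and $v(x,t)=x/(1+t)$. The sketch below is an illustrative toy, not part of our method: it checks Eq. (89) with finite differences and shows that the total derivative along a trajectory equals $-\nabla\cdot v$, whereas the partial derivative at a fixed point does not.

```python
import numpy as np

def a(t):
    """Toy scale factor: x_t = a(t) * x_0 with x_0 ~ N(0, 1)."""
    return 1.0 + t

def log_p(x, t):
    """Log-density of the marginal p_t = N(0, a(t)^2)."""
    return -0.5 * (x / a(t)) ** 2 - np.log(a(t)) - 0.5 * np.log(2.0 * np.pi)

def v(x, t):
    """ODE velocity dx_t/dt = a'(t) / a(t) * x_t, with a'(t) = 1."""
    return x / a(t)

def div_v(x, t):
    """Divergence of v in 1D, i.e. dv/dx."""
    return 1.0 / a(t)

h, x, t = 1e-5, 0.8, 0.3

# Partial derivative at the fixed point x, and the right-hand side of Eq. (89).
partial_t = (log_p(x, t + h) - log_p(x, t - h)) / (2 * h)
score = (log_p(x + h, t) - log_p(x - h, t)) / (2 * h)
rhs_89 = -div_v(x, t) - v(x, t) * score

# Total derivative along the trajectory through x at time t: x_s = a(s) * x0.
x0 = x / a(t)
total = (log_p(a(t + h) * x0, t + h) - log_p(a(t - h) * x0, t - h)) / (2 * h)

print(f"partial={partial_t:.5f} eq89={rhs_89:.5f} total={total:.5f} -div_v={-div_v(x, t):.5f}")
```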

Second, Choi et al. (2022) train a separate model to estimate the partial derivative $\partial_{t}\log p_{t}(\bm{x})$, which is independent of the ODE velocity $\bm{v}_{\theta}(\bm{x},t)$ (expressed through the parameterized score function $\bm{s}_{\theta}(\bm{x},t)$). In contrast, our method regularizes the parameterized velocity $\bm{v}_{\theta}(\bm{x},t)$ itself and does not require an additional model.

Finally, Choi et al. (2022) and our work rely on different techniques. Choi et al. (2022) estimate the score matching loss for the partial derivative $\partial_{t}\log p_{t}(\bm{x})$ via integration by parts, the same trick used to derive sliced score matching (Song et al., 2020), which avoids computing the score function $\nabla_{\bm{x}}\log p_{t}(\bm{x})$. In contrast, our method estimates the divergence of $\bm{v}_{\theta}(\bm{x},t)$ by exploiting the fact that the minimizer of a mean squared error is the conditional mean, the same property used to derive denoising score matching (Vincent, 2011). In the score matching literature, sliced score matching and denoising score matching are two quite different techniques. Since first-order denoising score matching is widely used to train diffusion models (e.g., Song et al. (2021c)), our proposed second-order flow matching is likewise suitable for training diffusion ODEs.

Appendix H Details of our adaptive IS

In this section, we give details of the adaptive importance sampling (IS) described in Section 4.4. First, we parameterize $\gamma_{\eta}(t)$ similarly to Kingma et al. (2021):

$$\gamma_{\eta}(t)=\gamma_{0}+(\gamma_{T}-\gamma_{0})\frac{\tilde{\gamma}_{\eta}(t)-\tilde{\gamma}_{\eta}(0)}{\tilde{\gamma}_{\eta}(1)-\tilde{\gamma}_{\eta}(0)} \tag{90}$$

where $\tilde{\gamma}_{\eta}(t)$ is a monotonically increasing dense network. Concretely, we use a two-layer fully connected network $\tilde{\gamma}_{\eta}(t)=l_{2}(\phi(l_{1}(t)))$, where $\phi$ is the sigmoid activation function and $l_{1},l_{2}$ are linear layers with positive weights and 1024 and 1 output units, respectively.
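A minimal NumPy sketch of Eqn. (90) with the two-layer monotone network described above. Positivity of the weights is enforced here with a softplus reparameterization, which is an assumption (other choices, such as taking absolute values, would also work). The initialization, the function names, and the use of the endpoint values $-13.3$ and $5.0$ listed in Appendix I as $\gamma_0,\gamma_T$ are likewise illustrative.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(rng, hidden=1024):
    # Unconstrained parameters; softplus below maps the weights to positive values.
    return {
        "w1": rng.normal(size=(1, hidden)),
        "b1": rng.normal(size=hidden),
        "w2": rng.normal(size=(hidden, 1)),
        "b2": np.zeros(1),
    }

def gamma_tilde(params, t):
    """l2(sigmoid(l1(t))) with positive weights, hence increasing in t."""
    t = np.asarray(t, dtype=float).reshape(-1, 1)
    h = sigmoid(t @ softplus(params["w1"]) + params["b1"])
    return (h @ softplus(params["w2"]) + params["b2"]).ravel()

def gamma_eta(params, t, gamma_0=-13.3, gamma_T=5.0):
    """Eqn. (90): rescale gamma_tilde so that gamma_eta(0) = gamma_0 and gamma_eta(1) = gamma_T."""
    g0, g1 = gamma_tilde(params, 0.0), gamma_tilde(params, 1.0)
    return gamma_0 + (gamma_T - gamma_0) * (gamma_tilde(params, t) - g0) / (g1 - g0)

params = init_params(np.random.default_rng(0))
ts = np.linspace(0.0, 1.0, 101)
gammas = gamma_eta(params, ts)
print(gammas[0], gammas[-1], bool(np.all(np.diff(gammas) > 0)))
```

The normalization in Eqn. (90) pins the endpoints regardless of the network parameters, so only the shape of the schedule between them is learned.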

Algorithm 1 Adaptive importance sampling (single iteration)

Require: velocity network $\bm{v}_{\theta}$, IS network $\tilde{\gamma}_{\eta}$, noise schedule $\alpha_{\gamma},\sigma_{\gamma}$, batch size $N$

Sample $\bm{x}_{0}^{(1)},\dots,\bm{x}_{0}^{(N)}$ from the data distribution

Sample $\bm{\epsilon}^{(1)},\dots,\bm{\epsilon}^{(N)}$ from the standard Gaussian distribution $\mathcal{N}(\bm{0},\bm{I})$

Sample $t^{(1)},\dots,t^{(N)}$ from the uniform distribution $\mathcal{U}(0,1)$

Calculate $\gamma_{\eta}(t^{(i)}),\ i=1,\dots,N$ by Eqn. (90)

Fix $\eta$, optimize $\theta$ to minimize $\frac{1}{N}\sum_{i=1}^{N}\mathcal{L}_{\theta,\eta}(\bm{x}_{0}^{(i)},\bm{\epsilon}^{(i)},t^{(i)})$

Fix $\theta$, optimize $\eta$ to minimize $\frac{1}{N}\sum_{i=1}^{N}\mathcal{L}^{2}_{\theta,\eta}(\bm{x}_{0}^{(i)},\bm{\epsilon}^{(i)},t^{(i)})$

Algorithm 1 presents our adaptive IS procedure. Kingma et al. (2021) propose reusing the gradient $\nabla_{\theta}\mathcal{L}_{\theta,\eta}(\bm{x}_{0},\bm{\epsilon},t)$ to optimize $\eta$, avoiding a second backpropagation by decomposing the gradient $\nabla_{\eta}\mathcal{L}_{\theta,\eta}^{2}(\bm{x}_{0},\bm{\epsilon},t)$ with the chain rule. We simplify their learning of $\tilde{\gamma}_{\eta}$ by removing this complex gradient operation within one iteration and instead alternately optimizing $\theta$ and $\eta$. This may incur extra overhead, but it still seeks the optimal IS and suffices for our ablation.
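The reason for minimizing the second moment in Algorithm 1 can be illustrated with a toy problem. Assume, hypothetically, that the objective is an integral $\int_{\gamma_0}^{\gamma_T} f(\gamma)\,\mathrm{d}\gamma$ estimated by sampling $t\sim\mathcal{U}(0,1)$ and weighting by $\gamma_\eta'(t)$. The estimate is unbiased for every monotone schedule, so minimizing its second moment over $\eta$ minimizes its variance without changing the objective. The sketch below replaces the network $\tilde{\gamma}_\eta$ and the gradient steps with a one-parameter schedule family and a grid search over $\eta$; $f$, the family, and the range are all made up for illustration, and there is no $\theta$ step.

```python
import numpy as np

GAMMA_0, GAMMA_T = 0.0, 3.0  # toy range, not the endpoints used in our experiments

def f(gamma):
    """Toy integrand standing in for the per-gamma loss density (hypothetical)."""
    return np.exp(gamma)

def schedule(t, eta):
    """Toy monotone map [0, 1] -> [GAMMA_0, GAMMA_T] and its derivative d gamma / d t."""
    c = np.expm1(eta)                 # e^eta - 1 > 0 for eta > 0
    g = np.log1p(c * t) / eta         # g(0) = 0, g(1) = 1, increasing
    dg = c / (eta * (1.0 + c * t))
    return GAMMA_0 + (GAMMA_T - GAMMA_0) * g, (GAMMA_T - GAMMA_0) * dg

def per_sample_loss(t, eta):
    """Importance-weighted estimate, unbiased for the integral of f for every eta."""
    gamma, dgamma_dt = schedule(t, eta)
    return dgamma_dt * f(gamma)

exact = np.exp(GAMMA_T) - np.exp(GAMMA_0)  # integral of exp over [GAMMA_0, GAMMA_T]
t = np.random.default_rng(0).uniform(size=400_000)

etas = np.linspace(0.1, 6.0, 60)
means = np.array([per_sample_loss(t, e).mean() for e in etas])
second_moments = np.array([np.mean(per_sample_loss(t, e) ** 2) for e in etas])
variances = second_moments - means ** 2

best = np.argmin(second_moments)  # analogue of "fix theta, optimize eta"
print(f"best eta = {etas[best]:.2f}, variance there = {variances[best]:.2e}, "
      f"variance at eta = 0.1: {variances[0]:.1f}")
```

In this toy family, $\eta=3$ makes $\gamma_\eta'(t)f(\gamma_\eta(t))$ constant, so the search recovers a zero-variance estimator while the mean stays at the exact integral for every $\eta$.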

Appendix I Experiment details

In this section, we provide details of our experimental settings. For a given dataset, we use the same network, hyperparameters, and training procedure for all noise schedules.

Model architectures

Our diffusion ODEs are parameterized in terms of the $\gamma$-timed normalized velocity predictor $\tilde{\bm{v}}_{\theta}(\bm{x}_{\gamma},\gamma)$, based on the U-Net architecture of Kingma et al. (2021). This architecture is tailored for maximum likelihood training, with special designs such as removing the internal downsampling/upsampling and adding Fourier features for fine-scale prediction. Our configuration for each dataset also follows Kingma et al. (2021): for CIFAR-10, we use a U-Net of depth 32 with 128 channels; for ImageNet-32, we also use a U-Net of depth 32 but double the number of channels to 256. All our models use a dropout rate of 0.1 in the intermediate layers. For CIFAR-10 (with data augmentation), we use a U-Net of depth 32 with 256 channels and decrease the dropout rate to 0.05.
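For reference, the per-dataset architecture settings above can be collected in a small configuration sketch. The class, field, and key names are hypothetical and are not taken from our code release; only the numeric values come from the text.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class UNetConfig:
    depth: int       # U-Net depth as stated in the text
    channels: int    # number of channels
    dropout: float   # dropout rate in the intermediate layers

CONFIGS = {
    "cifar10": UNetConfig(depth=32, channels=128, dropout=0.1),
    "imagenet32": UNetConfig(depth=32, channels=256, dropout=0.1),
    "cifar10_aug": UNetConfig(depth=32, channels=256, dropout=0.05),
}

for name, cfg in CONFIGS.items():
    print(name, cfg)
```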

Hyperparameters and training

We follow the same default training settings as Kingma et al. (2021). For all our experiments, we use the Adam (Kingma & Ba, 2014) optimizer with learning rate $2\times10^{-4}$, exponential decay rates $\beta_{1}=0.9,\beta_{2}=0.99$, and decoupled weight decay (Loshchilov & Hutter, 2019) with coefficient 0.01. We also maintain an exponential moving average (EMA) of the model parameters with an EMA rate of 0.9999 for evaluation.
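A minimal NumPy sketch of one optimizer step and one EMA update with the values above. It follows the convention of common implementations such as PyTorch's `torch.optim.AdamW`, where the decay coefficient is scaled by the learning rate; the $\epsilon=10^{-8}$ value and the function names are assumptions not stated in the text.

```python
import numpy as np

LR, B1, B2, EPS, WD, EMA_RATE = 2e-4, 0.9, 0.99, 1e-8, 0.01, 0.9999

def adamw_step(param, grad, m, v, step):
    """One Adam step with decoupled weight decay; `step` starts at 1."""
    m = B1 * m + (1 - B1) * grad
    v = B2 * v + (1 - B2) * grad ** 2
    m_hat = m / (1 - B1 ** step)  # bias correction
    v_hat = v / (1 - B2 ** step)
    # Decoupled decay: applied to the parameters directly, not added to the gradient.
    param = param - LR * (m_hat / (np.sqrt(v_hat) + EPS) + WD * param)
    return param, m, v

def ema_update(ema, param, rate=EMA_RATE):
    """Exponential moving average of parameters, used only for evaluation."""
    return rate * ema + (1 - rate) * param

p = np.ones(3)
g = np.array([2.0, -3.0, 0.5])
p1, m1, v1 = adamw_step(p, g, np.zeros(3), np.zeros(3), step=1)
ema = ema_update(np.zeros(3), p)
print(p1, ema)
```

On the first step the bias-corrected update is approximately `sign(grad)`, so each parameter moves by about `LR` plus the decay term `LR * WD * param`.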

For other hyperparameters, we use fixed start and end times satisfying $\gamma_{\epsilon}=-13.3$ and $\gamma_{T}=5.0$, which is the default setting in Kingma et al. (2021). In the finetuning stage, we simply set the coefficient $\lambda$ in the mixed loss $\mathcal{J}_{\text{FM}}(\theta)+\lambda\mathcal{J}_{\text{FM},\mathrm{tr}}(\theta)$ to 0.1 without further tuning, so that the magnitude of the second-order loss is negligible relative to the first-order loss. Since accurate first-order matching is critical for second-order matching, a large $\lambda$ makes training unstable and can even degrade the likelihood.
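As a worked example of these endpoints, assume the VP convention of Kingma et al. (2021), where $\sigma_\gamma^2=\mathrm{sigmoid}(\gamma)$ and $\alpha_\gamma^2=\mathrm{sigmoid}(-\gamma)$, hence $\log(\alpha_\gamma^2/\sigma_\gamma^2)=-\gamma$. This convention is an assumption here and does not cover the other schedules. The sketch also writes out the mixed finetuning loss with $\lambda=0.1$; the function names are hypothetical.

```python
import math
import numpy as np

def vp_alpha_sigma(gamma):
    """alpha_gamma, sigma_gamma under the assumed convention sigma^2 = sigmoid(gamma)."""
    sigma2 = 1.0 / (1.0 + np.exp(-gamma))
    alpha2 = 1.0 / (1.0 + np.exp(gamma))
    return np.sqrt(alpha2), np.sqrt(sigma2)

def mixed_finetune_loss(j_fm, j_fm_tr, lam=0.1):
    """J_FM(theta) + lambda * J_FM,tr(theta), with lambda = 0.1 as in the text."""
    return j_fm + lam * j_fm_tr

for gamma in (-13.3, 5.0):
    alpha, sigma = vp_alpha_sigma(gamma)
    log_snr = np.log(alpha ** 2 / sigma ** 2)
    print(f"gamma={gamma:+.1f} alpha={alpha:.6f} sigma={sigma:.3e} logSNR={log_snr:+.2f}")
```

Under this convention, $\gamma_\epsilon=-13.3$ gives $\sigma\approx1.29\times10^{-3}$ (nearly clean data), and $\gamma_T=5.0$ gives $\sigma^2\approx0.993$ (nearly pure noise).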

All our training runs are conducted on 8 NVIDIA A40 GPUs, except for ImageNet-32 (old version) and CIFAR-10 (with data augmentation). For CIFAR-10, we pretrain the model for 6 million iterations, which takes around 3 weeks, and then finetune it for 200k iterations, which takes around 1 day. For ImageNet-32 (new version), we pretrain the model for 2 million iterations, which takes around 2 weeks, and then finetune it for 250k iterations, which takes around 3 days. We use a batch size of 128 for both training stages on both datasets.

Note that in related works (Lipman et al., 2022; Albergo & Vanden-Eijnden, 2022), experiments on ImageNet-32 (new version) are conducted at a larger batch size (512 or 1024), which may improve the results. We did not use a larger batch size or train longer due to resource limitations.

For ImageNet-32 (old version), training is conducted on 8 NVIDIA A100 (40GB) GPUs. We pretrain the model for 2 million iterations using a batch size of 512, which takes around 2 weeks. Then we finetune the model for 500k iterations using a batch size of 128 and accumulate gradients over every 4 batches (an effective batch size of 512), which takes around 2.5 days.
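When the loss is a mean over examples, averaging the gradients of 4 batches of 128 reproduces the gradient of a single batch of 512. A minimal NumPy sketch on a hypothetical least-squares loss (standing in for the actual training loss) checks this equivalence.

```python
import numpy as np

def grad_mean_sq_loss(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)**2) w.r.t. w (toy per-example-mean loss)."""
    return X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X, y = rng.normal(size=(512, 4)), rng.normal(size=512)
w = rng.normal(size=4)

micro_batch, accum_steps = 128, 4
accum = np.zeros_like(w)
for k in range(accum_steps):
    rows = slice(k * micro_batch, (k + 1) * micro_batch)
    accum += grad_mean_sq_loss(w, X[rows], y[rows])
accum /= accum_steps  # average, so one update matches a batch of 512

full = grad_mean_sq_loss(w, X, y)
print(np.allclose(accum, full))
```

The optimizer then takes a single step with `accum`, so 500k finetuning iterations correspond to 500k updates at an effective batch size of 512 only if an "iteration" counts one update; the text does not specify which convention is used.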

For CIFAR-10 (with data augmentation), training is conducted on a cluster of 64 NVIDIA A800 (80GB) GPUs. We pretrain the model for 2 million iterations using a batch size of 1024, which takes around 2 weeks. Due to the large training resource requirements and the regularization effect of data augmentation, we do not further finetune the model with the second-order flow matching loss.

Likelihood and sample quality

For likelihood evaluation, we use our truncated-normal dequantization. When the number of importance samples is $K=1$, we report the BPD on the test dataset averaged over 5 repetitions to reduce the variance of the trace estimator. When $K=5$ or $K=20$, we do not repeat the dataset, since the log-likelihood of each data sample is already evaluated multiple times. For sampling, since we focus on the ODE, we simply use an adaptive-step ODE solver with the RK45 method (Dormand & Prince, 1980), with relative tolerance $10^{-5}$ and absolute tolerance $10^{-5}$. We generate 50k samples and report their FIDs. Using high-quality sampling procedures such as the PC sampler (Song et al., 2021c) or fast sampling algorithms such as DPM-Solver (Lu et al., 2022b) may improve the results; we leave this for future work.
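The sketch below illustrates one way $K$ importance samples can be combined into a per-example estimate and converted to BPD. It assumes, hypothetically, that each dequantization sample $\bm{u}_k\sim q(\cdot|\bm{x})$ yields a log-weight $\log p(\bm{x}+\bm{u}_k)-\log q(\bm{u}_k|\bm{x})$ in nats on the discrete data scale; computing these weights (the ODE likelihood and the truncated-normal dequantization) is not shown, and the log-weight values are synthetic. By Jensen's inequality, the $K$-sample estimate is never worse than the average of $K$ single-sample estimates.

```python
import numpy as np

def logmeanexp(a, axis=-1):
    """Numerically stable log(mean(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))

def bpd_from_log_weights(log_w, num_dims):
    """log_w: [batch, K] importance log-weights in nats; returns per-example BPD."""
    log_px = logmeanexp(log_w, axis=-1)  # K-sample importance-weighted estimate of log p(x)
    return -log_px / (num_dims * np.log(2.0))

rng = np.random.default_rng(0)
d = 3 * 32 * 32  # e.g. a 32x32 RGB image
# Synthetic log-weights around 2.5 bits/dim, for illustration only.
log_w = rng.normal(loc=-2.5 * d * np.log(2.0), scale=5.0, size=(4, 20))

bpd_k20 = bpd_from_log_weights(log_w, d)                  # K = 20
bpd_k1_avg = np.mean(-log_w / (d * np.log(2.0)), axis=1)  # average of K = 1 estimates
print(bpd_k20.mean(), bpd_k1_avg.mean())
```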

Appendix J Additional samples

Figure 6: Random samples for ablation by ODE sampler: (a) VDM (Kingma et al., 2021); (b) Our pretrain; (c) Our pretrain+finetune. Our pretraining and finetuning lead to a better likelihood with only a slight degradation in visual quality.

Figure 7: Random samples by ODE sampler (CIFAR-10, SP).

Figure 8: Random samples by ODE sampler (CIFAR-10, VP).

Figure 9: Random samples by ODE sampler (CIFAR-10, VP, with augmentation).

Figure 10: Random samples by ODE sampler (ImageNet-32, VP).

Figure 11: Random samples by ODE sampler (ImageNet-32, SP).