
Score-Based Generative Modeling through Stochastic Differential Equations

2022-06-21 15:49:00 Steamed bread and flower rolls

Song Y., Sohl-Dickstein J., Kingma D. P., Kumar A., Ermon S. and Poole B. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations (ICLR), 2021

Overview

This paper views diffusion models from the perspective of stochastic differential equations (SDEs).

Notation

  • \(\bm{x}(t), t \in [0, T]\): the state of \(\bm{x}\) at time \(t\);
  • \(p_t(\bm{x}) = p(\bm{x}(t))\): the distribution of \(\bm{x}\) at time \(t\);
  • \(p_{st}(\bm{x}(t)|\bm{x}(s)), 0 \le s < t \le T\): the transition kernel from \(\bm{x}(s)\) to \(\bm{x}(t)\);
  • \(\bm{s}_{\theta}(\bm{x}, t)\): an approximation of the score \(\nabla_{\bm{x}} \log p_t(\bm{x})\), usually fitted by a neural network.

Wiener process

A Wiener process \(X(t, w)\) is a stochastic process such that:

  1. \(X(0) = 0\);
  2. \(X(t+\Delta t) - X(t)\) is independent of \(X(s)\) for \(s \le t\) (independent increments, so the process is Markov);
  3. \(X(t + \Delta t) - X(t) \sim \mathcal{N}(0, \Delta t)\), i.e. the increment is normally distributed with variance \(\Delta t\);
  4. \(\lim_{\Delta t \rightarrow 0} X(t + \Delta t) = X(t)\), i.e. \(X\) is continuous in \(t\).

This post focuses on the Wiener process with drift \(\mu\):

\[X(t, w) = \mu t + \sigma W_t, \]

where \(W_t\) is a standard Wiener process.

The increments of such a stochastic process can be described by the following SDE (general form):

\[\tag{SDE+} \text{d} \bm{x} = \bm{f}(\bm{x}, t) \text{d} t + \bm{G}(\bm{x}, t) \text{d} \bm{w}, \]

where

\[\bm{f}(\cdot, t): \mathbb{R}^d \rightarrow \mathbb{R}^d, \\ \bm{G}(\cdot, t): \mathbb{R}^d \rightarrow \mathbb{R}^{d \times d}. \]

Here \(\text{d} \bm{w}\) is the increment of a standard Wiener process, i.e. \(\bm{w}(t + \Delta t) - \bm{w}(t) \sim \mathcal{N}(\bm{0}, \Delta t I)\).
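Here is a minimal sketch of how (SDE+) can be simulated. It assumes the scalar case (\(d = 1\), with \(\bm{G}\) reduced to a scalar \(g\)) and the plain Euler-Maruyama scheme. The function name and parameter values are illustrative and do not come from the paper. The example uses the drifted Wiener process above (\(f = \mu\), \(g = \sigma\)), whose exact marginal at time \(T\) has mean \(\mu T\) and variance \(\sigma^2 T\).

```python
import math
import random


def euler_maruyama(f, g, x0, T, n_steps, rng):
    """Simulate one path of the scalar SDE dx = f(x, t) dt + g(x, t) dw and return x(T)."""
    dt = T / n_steps
    x, t = x0, 0.0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))   # Wiener increment w(t + dt) - w(t) ~ N(0, dt)
        x = x + f(x, t) * dt + g(x, t) * dw
        t += dt
    return x


# Drifted Wiener process X(t) = mu * t + sigma * W_t, i.e. f = mu and g = sigma.
mu, sigma, T = 0.5, 2.0, 1.0
rng = random.Random(0)
samples = [euler_maruyama(lambda x, t: mu, lambda x, t: sigma, 0.0, T, 50, rng)
           for _ in range(20000)]
mean = sum(samples) / len(samples)
var = sum((v - mean) ** 2 for v in samples) / (len(samples) - 1)
# mean should be close to mu * T = 0.5 and var close to sigma**2 * T = 4.0
```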

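Its reverse-time process can be described as:

\[\tag{SDE-} \text{d} \bm{x} = \{ \bm{f}(\bm{x}, t) - \nabla \cdot [\bm{G}(\bm{x}, t) \bm{G}(\bm{x}, t)^T] - \bm{G}(\bm{x}, t) \bm{G}(\bm{x}, t)^T \nabla_{\bm{x}} \log p_t(\bm{x}) \} \text{d} t + \bm{G}(\bm{x}, t) \text{d} \bm{w}, \]

where time flows backward from \(T\) to \(0\), \(\text{d}t\) is a negative infinitesimal time step, and \(\bm{w}\) is a standard Wiener process running backward in time.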

Main content

SMLD and DDPM both consist of:

  1. \(\bm{x}(0) \rightarrow \bm{x}(T)\): a forward process that gradually adds noise;
  2. \(\bm{x}(T) \rightarrow \bm{x}(0)\): a reverse process that samples step by step.

These two processes can be regarded as discretizations of a pair of (forward and reverse) SDEs.

Reverse sampling

Let's start with reverse sampling, which makes some of the design choices in the forward process easier to understand. Once we have (SDE-) and the score function \(\nabla_{\bm{x}} \log p_t(\bm{x})\), we can "generate" a solution \(\bm{x}(0)\) step by step with a discretization method.

Numerical SDE solvers

Many numerical methods can be used for reverse sampling: Euler-Maruyama, stochastic Runge-Kutta methods, and ancestral sampling.

The paper proposes reverse diffusion sampling (ancestral sampling is a special case of it):

  1. For the forward SDE

    \[\text{d} \bm{x} = \bm{f}(\bm{x}, t) \text{d} t + \bm{G}(\bm{x}, t) \text{d} \bm{w}, \]

    use the update rule

    \[\bm{x}_{i + 1} = \bm{x}_i + \bm{f}_i(\bm{x}_i) + \bm{G}_i \bm{z}_i, \: \bm{z}_i \sim \mathcal{N}(\bm{0}, I), \: i=0,1,\cdots, N - 1, \]

    where \(\bm{f}_i\) and \(\bm{G}_i\) are the discretized drift and diffusion coefficients (the time step is absorbed into them);
  2. Similarly, for the (simplified) reverse SDE, in which \(\bm{G}\) depends only on \(t\) so that the divergence term vanishes,

    \[\text{d} \bm{x} = \{ \bm{f}(\bm{x}, t) - \bm{G}(t) \bm{G}(t)^T \nabla_{\bm{x}} \log p_t(\bm{x}) \} \text{d} t + \bm{G}(t) \text{d} \bm{w}, \]

    use (note that the signs are reversed)

    \[\bm{x}_i = \bm{x}_{i + 1} - \bm{f}_{i+1}(\bm{x}_{i+1}) + \bm{G}_{i+1} \bm{G}_{i+1}^T \nabla_{\bm{x}} \log p_{i+1}(\bm{x}_{i+1}) + \bm{G}_{i+1} \bm{z}_{i+1}. \]
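As a sanity check of the reverse update, the sketch below applies it to the SMLD/VE case in one dimension (\(\bm{f} = \bm{0}\), \(\bm{G}_{i+1}^2 = \sigma_{i+1}^2 - \sigma_i^2\); see the SMLD section below). It assumes toy data \(\bm{x}(0) \sim \mathcal{N}(m, s^2)\). For that data, the exact score of the perturbed marginal, \(-(x - m)/(s^2 + \sigma^2)\), stands in for a trained \(\bm{s}_{\theta}\). The noise schedule, the sample counts and the helper names are hypothetical.

```python
import math
import random


def geometric_sigmas(sigma_min, sigma_max, n):
    """[sigma_0 = 0, sigma_1 = sigma_min, ..., sigma_n = sigma_max], geometrically spaced."""
    ratio = (sigma_max / sigma_min) ** (1.0 / (n - 1))
    return [0.0] + [sigma_min * ratio ** k for k in range(n)]


def reverse_diffusion_ve(score, sigmas, x, rng):
    """x_i = x_{i+1} + G_{i+1}^2 * score(x_{i+1}, sigma_{i+1}) + G_{i+1} * z, with
    f = 0 and G_{i+1}^2 = sigma_{i+1}^2 - sigma_i^2 (VE / SMLD case)."""
    for i in range(len(sigmas) - 2, -1, -1):          # i = N-1, ..., 0
        g2 = sigmas[i + 1] ** 2 - sigmas[i] ** 2
        x = x + g2 * score(x, sigmas[i + 1]) + math.sqrt(g2) * rng.gauss(0.0, 1.0)
    return x


m, s = 1.0, 0.5                                        # toy data distribution N(m, s^2)


def exact_score(x, sigma):
    return -(x - m) / (s ** 2 + sigma ** 2)            # score of N(m, s^2 + sigma^2)


sigmas = geometric_sigmas(0.01, 20.0, 200)
rng = random.Random(0)
# Start from the prior N(0, sigma_max^2) and run the reverse chain down to sigma_0 = 0.
xs = [reverse_diffusion_ve(exact_score, sigmas, rng.gauss(0.0, sigmas[-1]), rng)
      for _ in range(4000)]
mean = sum(xs) / len(xs)
std = math.sqrt(sum((v - mean) ** 2 for v in xs) / (len(xs) - 1))
# mean should be close to m = 1.0 and std close to s = 0.5
```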

Predictor-corrector samplers

Suppose we know \(\nabla_{\bm{x}} \log p_t(\bm{x})\) or an approximation of it, \(\bm{s}_{\theta}(\bm{x}, t)\). Then we can sample with score-based MCMC methods such as Langevin MCMC and HMC.

With Langevin MCMC, each step is:

\[\bm{x} \leftarrow \bm{x} + \epsilon \nabla_x \log p(\bm{x}) + \sqrt{2\epsilon} \bm{z}, \: \bm{z} \mathop{\sim} \limits^{i.i.d.} \mathcal{N}(\bm{0}, I), \]

where \(\epsilon\) is the step size.
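The Langevin update can be sketched in a few lines. The sketch assumes a 1-D Gaussian target \(\mathcal{N}(3, 1.5^2)\), whose score is known in closed form. The step size, chain length and names are illustrative. Because there is no Metropolis correction, the chains are only approximately distributed as \(p(\bm{x})\), with an \(O(\epsilon)\) bias.

```python
import math
import random


def langevin(score, x, eps, n_steps, rng):
    """Unadjusted Langevin MCMC: x <- x + eps * score(x) + sqrt(2 * eps) * z, z ~ N(0, 1)."""
    for _ in range(n_steps):
        x = x + eps * score(x) + math.sqrt(2.0 * eps) * rng.gauss(0.0, 1.0)
    return x


m, s = 3.0, 1.5                        # target p(x) = N(3, 1.5^2)


def target_score(x):
    return -(x - m) / s ** 2           # grad_x log p(x)


rng = random.Random(1)
xs = [langevin(target_score, 0.0, 0.1, 200, rng) for _ in range(3000)]
mean = sum(xs) / len(xs)
std = math.sqrt(sum((v - mean) ** 2 for v in xs) / (len(xs) - 1))
# Every chain starts far away at x = 0, yet the final points are close to N(3, 1.5^2).
```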

Note: MCMC sampling only guarantees that the successively sampled points eventually follow the distribution \(p(\bm{x})\). It does not mean that the whole trajectory follows the reverse stochastic process!

In the overall PC sampler framework, the Predictor can be any numerical SDE solver and the Corrector is an MCMC method. In other words, the stochastic process is first solved numerically. Discretization errors may push the actual \(\bm{x}_i\) away from its target marginal distribution, so MCMC is then used to correct it.
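A minimal PC sketch for the same 1-D VE toy setting as above (Gaussian data \(\mathcal{N}(m, s^2)\), exact score in place of \(\bm{s}_{\theta}\)). The Predictor is one reverse-diffusion step. The Corrector is one Langevin step targeting the marginal at the new noise level. The corrector step size \(\epsilon_i = 2(r\sigma_i)^2\) with \(r = 0.16\) is a simplifying heuristic chosen here, not necessarily the paper's exact rule, and all names are hypothetical.

```python
import math
import random


def geometric_sigmas(sigma_min, sigma_max, n):
    ratio = (sigma_max / sigma_min) ** (1.0 / (n - 1))
    return [0.0] + [sigma_min * ratio ** k for k in range(n)]


def pc_sample_ve(score, sigmas, x, rng, r=0.16, n_corrector=1):
    """Predictor-Corrector sampling for the VE SDE (scalar toy version)."""
    for i in range(len(sigmas) - 2, -1, -1):
        # Predictor: one reverse-diffusion step from level i+1 to level i.
        g2 = sigmas[i + 1] ** 2 - sigmas[i] ** 2
        x = x + g2 * score(x, sigmas[i + 1]) + math.sqrt(g2) * rng.gauss(0.0, 1.0)
        # Corrector: Langevin steps targeting p_i, heuristic step size 2 * (r * sigma_i)^2.
        eps = 2.0 * (r * sigmas[i]) ** 2
        for _ in range(n_corrector):
            x = x + eps * score(x, sigmas[i]) + math.sqrt(2.0 * eps) * rng.gauss(0.0, 1.0)
    return x


m, s = 1.0, 0.5


def exact_score(x, sigma):
    return -(x - m) / (s ** 2 + sigma ** 2)


sigmas = geometric_sigmas(0.01, 20.0, 200)
rng = random.Random(0)
xs = [pc_sample_ve(exact_score, sigmas, rng.gauss(0.0, sigmas[-1]), rng) for _ in range(2000)]
mean = sum(xs) / len(xs)
std = math.sqrt(sum((v - mean) ** 2 for v in xs) / (len(xs) - 1))
```

At the last level (\(\sigma_0 = 0\)) the heuristic step size is zero, so the corrector does nothing there.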

Probability Flow

In this part, the authors convert the SDE into an ODE (the probability flow ODE), whose trajectories share the same marginal densities \(p_t\) as the SDE, so that sampling becomes deterministic. I did not fully understand this part, so I just note it down here. Unlike the SDE, the ODE contains no random term. Off-the-shelf black-box ODE solvers can therefore be used to solve it, and different initial values \(\bm{x}(T) \sim p_T\) give different solutions.

The general procedure is as follows:

\[\bm{x}_i = \bm{x}_{i + 1} - \bm{f}_{i + 1}(\bm{x}_{i + 1}) + \frac{1}{2}\bm{G}_{i+1}\bm{G}_{i+1}^T \bm{s}_{\theta}(\bm{x}_{i + 1}, i + 1), \: i=0, 1, \cdots, N - 1. \]
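Here is a minimal sketch of this deterministic update in the VE case (\(\bm{f} = \bm{0}\), \(\bm{G}_{i+1}^2 = \sigma_{i+1}^2 - \sigma_i^2\)). It assumes 1-D Gaussian toy data \(\mathcal{N}(m, s^2)\), so the exact score replaces \(\bm{s}_{\theta}\). For this toy case, the exact ODE keeps \((x - m)/\sqrt{s^2 + \sigma^2}\) constant, which gives a reference value to compare against. The schedule and names are hypothetical.

```python
import math


def geometric_sigmas(sigma_min, sigma_max, n):
    ratio = (sigma_max / sigma_min) ** (1.0 / (n - 1))
    return [0.0] + [sigma_min * ratio ** k for k in range(n)]


def probability_flow_ve(score, sigmas, x):
    """x_i = x_{i+1} + 0.5 * (sigma_{i+1}^2 - sigma_i^2) * score(x_{i+1}, sigma_{i+1});
    f = 0 in the VE case, and no noise is injected."""
    for i in range(len(sigmas) - 2, -1, -1):
        x = x + 0.5 * (sigmas[i + 1] ** 2 - sigmas[i] ** 2) * score(x, sigmas[i + 1])
    return x


m, s = 1.0, 0.5


def exact_score(x, sigma):
    return -(x - m) / (s ** 2 + sigma ** 2)


sigmas = geometric_sigmas(0.01, 20.0, 1000)
x_T = m + 10.0
x_0 = probability_flow_ve(exact_score, sigmas, x_T)
# Reference from the exact ODE: (x - m) / sqrt(s^2 + sigma^2) stays constant.
expected = m + (x_T - m) * s / math.sqrt(s ** 2 + sigmas[-1] ** 2)
```

Running it twice from the same \(\bm{x}(T)\) gives exactly the same \(\bm{x}(0)\), unlike the stochastic samplers.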

Conditional sampling

Conditional sampling: given \(\bm{y}(0)\), we want to sample from the conditional distribution

\[p(\bm{x}(0) |\bm{y}(0)). \]

Usually one would use Bayes' formula

\[p(\bm{x}(0) |\bm{y}(0)) = \frac{p(\bm{y}(0)|\bm{x}(0)) p(\bm{x}(0))}{p(\bm{y}(0))}, \]

but the prior \(p(\bm{x}(0))\) and the evidence \(p(\bm{y}(0))\) are often hard to estimate.

Instead, we can sample from \(p_t(\bm{x}(t) | \bm{y}(0))\) with the following reverse-time SDE:

\[\text{d} \bm{x} = \{ \bm{f}(\bm{x}, t) - \nabla \cdot [\bm{G}(\bm{x}, t) \bm{G}(\bm{x}, t)^T] - \bm{G}(\bm{x}, t) \bm{G}(\bm{x}, t)^T \nabla_{\bm{x}} \log p_t(\bm{x}(t)|\bm{y}(0)) \} \text{d} t + \bm{G}(\bm{x}, t) \text{d} \bm{w}. \]

Moreover, by Bayes' rule (the term \(\log p_t(\bm{y}(0))\) does not depend on \(\bm{x}\)),

\[\nabla_{\bm{x}} \log p_t (\bm{x}(t)|\bm{y}(0)) = \underbrace{\nabla_{\bm{x}} \log p_t(\bm{x}(t))}_{\approx \bm{s}_{\theta}(\bm{x}, t)} + \nabla_{\bm{x}} \log p_t(\bm{y}(0)|\bm{x}(t)), \]

so we can sample whenever \(\nabla_{\bm{x}} \log p_t (\bm{y}(0)|\bm{x}(t))\) is known.

Next, we discuss two cases: one where \(p_t(\bm{y}(0)|\bm{x}(t))\) can be estimated, and one where it is hard to estimate directly.

An estimable case
  1. \(\bm{y}(0)\) is the label in a classification task;
  2. sample \(\bm{x}(t)\) via the forward perturbation process;
  3. train a time-dependent classifier with the cross-entropy loss:

    \[p_t(\bm{y}(0) | \bm{x}(t)). \]
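The score decomposition used here can be checked numerically in a toy setting. The sketch assumes two classes with equal prior and class-conditional data \(\mathcal{N}(\mu_y, s^2)\), perturbed by VE noise of level \(\sigma\), so every density has a closed form. In practice, `log_classifier` would be the trained time-dependent classifier. Gradients are central finite differences, and all names and values are hypothetical.

```python
import math


def log_normal(x, mu, var):
    return -0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var)


def grad(fn, x, h=1e-5):
    """Central finite-difference derivative of a scalar function."""
    return (fn(x + h) - fn(x - h)) / (2 * h)


mus, s, sigma = (-2.0, 2.0), 1.0, 0.8
var = s ** 2 + sigma ** 2                 # variance after VE perturbation at level sigma


def log_p_x(x):
    """log p_t(x): equal-weight mixture of the two perturbed classes."""
    return math.log(0.5 * math.exp(log_normal(x, mus[0], var))
                    + 0.5 * math.exp(log_normal(x, mus[1], var)))


def log_classifier(y, x):
    """log p_t(y | x) by Bayes' rule."""
    return math.log(0.5) + log_normal(x, mus[y], var) - log_p_x(x)


x, y = 0.3, 1
guided = grad(log_p_x, x) + grad(lambda u: log_classifier(y, u), x)   # unconditional + classifier
direct = -(x - mus[y]) / var              # exact conditional score of N(mus[y], var)
```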

A case that is hard to estimate directly

In this case, note that:

\[\nabla_{\bm{x}} \log p_t(\bm{x}(t)|\bm{y}(0)) = \nabla_{\bm{x}} \log \int p_t(\bm{x}(t) | \bm{y}(t), \bm{y}(0)) p(\bm{y}(t) | \bm{y}(0)) \text{d} \bm{y}(t). \]

We make the following two reasonable assumptions:

  1. \(p(\bm{y}(t) | \bm{y}(0))\) is tractable;
  2. \(p_t(\bm{x}(t)|\bm{y}(t), \bm{y}(0)) \approx p_t(\bm{x}(t)|\bm{y}(t))\). This is because for small \(t\), \(\bm{y}(t) \approx \bm{y}(0)\), while for large \(t\), \(\bm{x}(t)\) is mostly affected by \(\bm{y}(t)\).

Then

\[\begin{array}{ll} \nabla_{\bm{x}} \log p_t(\bm{x}(t)|\bm{y}(0)) &\approx \nabla_{\bm{x}} \log \int p_t(\bm{x}(t) | \bm{y}(t)) p(\bm{y}(t) | \bm{y}(0)) \text{d} \bm{y}(t) \\ &\approx \nabla_{\bm{x}} \log p_t(\bm{x}(t)|\hat{\bm{y}}(t)) \: \leftarrow \hat{\bm{y}}(t) \sim p(\bm{y}(t)|\bm{y}(0)) \\ &=\nabla_{\bm{x}} \log p_t(\bm{x}(t)) + \nabla_{\bm{x}} \log p_t(\hat{\bm{y}}(t)|\bm{x}(t)) \\ &\approx \bm{s}_{\theta} (\bm{x}(t), t) + \nabla_{\bm{x}} \log p_t(\hat{\bm{y}}(t) | \bm{x}(t)). \end{array} \]

Here the integral is approximated with a single sample \(\hat{\bm{y}}(t)\). We then only need to plug \(\nabla_{\bm{x}} \log p_t(\hat{\bm{y}}(t)|\bm{x}(t))\) into the solver.

Take imputation as an example. Let \(\Omega(\bm{x})\) and \(\bar{\Omega}(\bm{x})\) denote the observed and missing parts, respectively. Our goal is to sample from

\[p(\bm{x}(0) | \Omega(\bm{x}(0)) = \bm{y}). \]

Following the steps above, we only need to estimate

\[\nabla_{\bm{x}} \log p_t (\bm{x}(t) | \hat{\Omega}(\bm{x}(t)) ). \]

In fact, since the modeling in this paper is element-wise,

\[p_t (\bm{x}(t) | \hat{\Omega}(\bm{x}(t)) ) = p_t (\bm{x}_{\hat{\Omega}}(t)), \]

that is, only the \(\hat{\Omega}\) region needs to be sampled.

Note: the content here differs considerably from the derivation in Appendix I.2 of the original paper. It is based on my own understanding and has not been tested experimentally, so its accuracy is questionable!

Forward perturbation

From the process above, we know that if we can estimate

\[\bm{s}_{\theta}(\bm{x}, t) \approx \nabla_{\bm{x}} \log p_t (\bm{x}), \]

then we can sample by following the reverse process step by step. This requires (denoising) score matching as the training objective:

\[\theta^* = \mathop{\arg \min} \limits_{\theta} \mathbb{E}_t \Bigg\{ \lambda (t) \mathbb{E}_{\bm{x}(0)} \mathbb{E}_{\bm{x}(t)|\bm{x}(0)} [\|\bm{s}_{\theta}(\bm{x}(t), t) - \nabla_{\bm{x}(t)} \log p_{0t} (\bm{x}(t)|\bm{x}(0))\|_2^2] \Bigg\}, \]

where \(\lambda(\cdot)\) is a positive weighting function, usually chosen as \(\lambda \propto 1 / \mathbb{E} [\|\nabla_{\bm{x}(t)} \log p_{0t} (\bm{x}(t)|\bm{x}(0))\|_2^2]\), and \(t \sim \mathcal{U}[0, T]\).
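Here is a minimal Monte Carlo sketch of this objective. It assumes a scalar VE-style kernel \(\bm{x}(t) = \bm{x}(0) + \sigma(t)\bm{z}\), for which \(\nabla_{\bm{x}(t)} \log p_{0t} = -(\bm{x}(t) - \bm{x}(0))/\sigma(t)^2\) and the recommended weight reduces to \(\lambda(t) = \sigma(t)^2\). The geometric \(\sigma(t)\), the data and the two candidate score functions are illustrative assumptions. The example only shows that the true score of the perturbed marginal attains a lower loss than a trivial zero score.

```python
import math
import random


def dsm_loss(score, data, sigma_min, sigma_max, n_samples, rng):
    """Monte Carlo estimate of E_t{ lambda(t) E || score(x(t), t) - grad log p_0t(x(t) | x(0)) ||^2 }."""
    total = 0.0
    for _ in range(n_samples):
        t = rng.random()                                   # t ~ U[0, 1]
        sigma = sigma_min * (sigma_max / sigma_min) ** t   # assumed geometric sigma(t)
        x0 = rng.choice(data)                              # x(0) from the training data
        xt = x0 + sigma * rng.gauss(0.0, 1.0)              # x(t) ~ N(x(0), sigma^2)
        target = -(xt - x0) / sigma ** 2                   # grad of log p_0t(x(t) | x(0))
        total += sigma ** 2 * (score(xt, sigma) - target) ** 2   # lambda(t) = sigma^2
    return total / n_samples


m, s = 0.0, 1.0
data_rng = random.Random(0)
data = [data_rng.gauss(m, s) for _ in range(2000)]


def true_score(x, sigma):
    return -(x - m) / (s ** 2 + sigma ** 2)   # score of the perturbed marginal N(m, s^2 + sigma^2)


def zero_score(x, sigma):
    return 0.0


# The same seed gives both models identical (t, x(0), z) draws.
loss_true = dsm_loss(true_score, data, 0.01, 10.0, 20000, random.Random(1))
loss_zero = dsm_loss(zero_score, data, 0.01, 10.0, 20000, random.Random(1))
```

With \(\lambda = \sigma^2\), the zero score has expected loss \(\mathbb{E}[z^2] = 1\) at every noise level, which makes the weighting's purpose visible: every \(t\) contributes on the same scale.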

From the definition of the objective above, it is generally only practical when \(p_{0t}\) can be computed explicitly. For more general stochastic processes, sliced score matching can be used to bypass this computation (at the cost of more compute). The processes described below all have Gaussian transition kernels that can be computed in closed form.

SMLD

SMLD defines a sequence \(\{\bm{x}_i\}_{i=1}^N\), which can be viewed as a discrete stochastic process at times \(t = \frac{i}{N} \in [0, T = 1]\):

\[\tag{1} \bm{x}_i = \bm{x}_{i-1} + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} \bm{z}_{i-1}, \: \bm{z}_i \mathop{\sim} \limits^{i.i.d.} \mathcal{N}(\bm{0}, I), \]

with

\[\sigma_{\min} = \sigma_1 < \sigma_2 < \cdots < \sigma_N = \sigma_{\max}. \]

Then, taking \(\sigma_0 = 0\) (so that \(\bm{x}_0\) is the clean data), we have:

\[\bm{x}_i|\bm{x}_0 \sim \mathcal{N}(\bm{x}_0, \sigma_i^2 I). \]
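This marginal can be checked with a quick Monte Carlo run. The sketch runs chain (1) with \(\sigma_0 = 0\) from a fixed scalar \(\bm{x}_0\), then compares the empirical mean and variance of \(\bm{x}_N\) with \(\bm{x}_0\) and \(\sigma_N^2\). The ten geometrically spaced \(\sigma\) values are an arbitrary choice for illustration.

```python
import math
import random


def smld_chain(x0, sigmas, rng):
    """Run x_i = x_{i-1} + sqrt(sigma_i^2 - sigma_{i-1}^2) z_{i-1}, with sigmas[0] = 0."""
    x = x0
    for i in range(1, len(sigmas)):
        x = x + math.sqrt(sigmas[i] ** 2 - sigmas[i - 1] ** 2) * rng.gauss(0.0, 1.0)
    return x


# sigma_0 = 0, then sigma_1 = 0.01, ..., sigma_10 = 50 (geometric spacing).
sigmas = [0.0] + [0.01 * (50.0 / 0.01) ** (k / 9) for k in range(10)]
rng = random.Random(0)
x0 = 2.0
xs = [smld_chain(x0, sigmas, rng) for _ in range(20000)]
mean = sum(xs) / len(xs)
var = sum((v - mean) ** 2 for v in xs) / (len(xs) - 1)
# mean should be close to x0 = 2 and var close to sigma_N^2 = 2500
```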

We further rewrite it in SDE form (letting \(N \rightarrow \infty\)):

\[\Delta \bm{x}(t) = \bm{x}(t + \Delta t) - \bm{x}(t) = \sqrt{\Delta \sigma^2 (t)} \bm{z}(t) = \sqrt{\frac{\Delta \sigma^2(t)}{\Delta t} \Delta t} \bm{z}(t), \]

where \(\Delta \sigma^2(t) = \sigma^2(t + \Delta t) - \sigma^2(t)\). When \(\Delta t \rightarrow 0\) (i.e. \(N \rightarrow \infty\)), we have:

\[\Delta \bm{x}(t) \rightarrow \text{d} \bm{x}(t), \\ \frac{\Delta \sigma^2 (t)}{\Delta t} \rightarrow \frac{\text{d}[\sigma^2(t)]}{\text{d}t}. \]

Finally, the increment \(\sqrt{\Delta t} \bm{z}(t) \sim \mathcal{N}(\bm{0}, \Delta t I)\), so the resulting random process is naturally a Wiener process, and therefore

\[\tag{2} \text{d}\bm{x} = \bm{0} \text{d}t + \sqrt{\frac{\text{d} \sigma^2 (t)}{\text{d} t}} \text{d} \bm{w}. \]

That is, there is no drift term.
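For a concrete diffusion coefficient, assume the geometric schedule \(\sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t\); the values below are arbitrary. Then \(\text{d}[\sigma^2(t)]/\text{d}t = 2\sigma^2(t)\log(\sigma_{\max}/\sigma_{\min})\), so the coefficient in (2) is \(\sigma(t)\sqrt{2\log(\sigma_{\max}/\sigma_{\min})}\). The sketch checks this closed form against a finite difference.

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.01, 50.0   # illustrative values


def sigma(t):
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def g_analytic(t):
    # d[sigma^2]/dt = 2 * sigma(t)^2 * log(sigma_max / sigma_min)
    return sigma(t) * math.sqrt(2.0 * math.log(SIGMA_MAX / SIGMA_MIN))


def g_numeric(t, h=1e-6):
    # sqrt of a central finite difference of sigma^2(t)
    return math.sqrt((sigma(t + h) ** 2 - sigma(t - h) ** 2) / (2 * h))


for t in (0.1, 0.5, 0.9):
    print(t, g_analytic(t), g_numeric(t))
```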

DDPM

DDPM defines a sequence \(\{\bm{x}_i\}_{i=1}^N\), which can be viewed as a discrete stochastic process at times \(t = \frac{i}{N} \in [0, T = 1]\):

\[\tag{3} \bm{x}_i = \sqrt{1 - \beta_i} \bm{x}_{i-1} + \sqrt{\beta_i} \bm{z}_{i-1}, \bm{z}_i \mathop{\sim} \limits^{i.i.d.} \mathcal{N}(\bm{0}, I). \]

Let \(\bar{\beta}_i := N \beta_i\), and define

\[\beta(t), t \in [0, 1], \: \beta(\frac{i}{N}) = \bar{\beta}_i. \]

Then (3) can be rewritten as

\[\tag{3+} \bm{x}(t + \Delta t) - \bm{x}(t) = (\sqrt{1 - \beta(t + \Delta t) \Delta t} - 1) \bm{x}(t) + \sqrt{\beta (t + \Delta t) \Delta t} \bm{z}(t), \]

When \(\Delta t \rightarrow 0\), we have

\[\bm{x}(t + \Delta t) - \bm{x}(t) = \Delta \bm{x}(t) \rightarrow \text{d} \bm{x}(t) \\ \sqrt{1 - \beta(t + \Delta t) \Delta t} - 1 \rightarrow -\frac{1}{2} \beta (t) \text{d} t \\ \sqrt{\beta (t + \Delta t) \Delta t} \bm{z}(t) \rightarrow \sqrt{\beta (t)} \text{d}\bm{w}. \]

The second limit follows from the first-order Taylor expansion \(\sqrt{1 - u} \approx 1 - \frac{u}{2}\), and the third follows from the same reasoning as in SMLD.

Finally, this can be summarized as the following SDE:

\[\tag{4} \text{d}\bm{x} = -\frac{1}{2} \beta (t) \bm{x} \text{d} t + \sqrt{\beta (t)} \text{d}\bm{w}. \]
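To see how close chain (3) is to its continuous limit (4), the sketch compares two closed-form perturbation kernels. For the DDPM chain, the mean coefficient is \(\prod_j \sqrt{1-\beta_j}\) and the variance is \(1 - \prod_j (1 - \beta_j)\). For the VP SDE, the mean coefficient is \(e^{-\frac{1}{2}\int_0^t \beta(s)\text{d}s}\) and the variance is \(1 - e^{-\int_0^t \beta(s)\text{d}s}\), as derived below. The sketch assumes a linear \(\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})\) with \(\beta_{\min} = 0.1\), \(\beta_{\max} = 20\) and \(N = 1000\); these values are illustrative.

```python
import math

BETA_MIN, BETA_MAX, N = 0.1, 20.0, 1000   # illustrative linear schedule


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def ddpm_kernel(i):
    """(mean coefficient, variance) of x_i | x_0 for chain (3) with beta_j = beta(j / N) / N."""
    prod = 1.0
    for j in range(1, i + 1):
        prod *= 1.0 - beta(j / N) / N
    return math.sqrt(prod), 1.0 - prod


def vp_kernel(t):
    """(mean coefficient, variance) of x(t) | x(0) for the VP SDE (4)."""
    integral = BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2   # int_0^t beta(s) ds
    return math.exp(-0.5 * integral), 1.0 - math.exp(-integral)


for i in (250, 500, 1000):
    print(i / N, ddpm_kernel(i), vp_kernel(i / N))
```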

Next, let's derive the conditional distribution of \(\bm{x}(t)\) for DDPM. Taking expectations on both sides of (3+) (the noise term has zero mean) gives

\[\bm{e}(t + \Delta t) - \bm{e}(t) = (\sqrt{1 - \beta(t + \Delta t) \Delta t} - 1) \bm{e}(t) + \bm{0}, \]

where \(\bm{e}(t) = \mathbb{E}[\bm{x}(t)]\); hence

\[\text{d} \bm{e} = -\frac{1}{2} \beta (t) \bm{e} \text{d} t. \]

With the initial condition \(\bm{e}(0) = \bm{e}_0\), we obtain:

\[\bm{e}(t) = \bm{e}(0) e^{-\frac{1}{2} \int_0^t \beta (s) \text{d}s}. \]

The covariance matrix \(\Sigma_{VP}(t)\) of \(\bm{x}(t)\) satisfies

\[\text{d}\Sigma_{VP}(t) = \beta (t) (I - \Sigma_{VP}(t)) \text{d}t, \]

and with the initial value \(\Sigma_{VP}(0)\) we obtain

\[\Sigma_{VP}(t) = I + e^{-\int_0^t \beta(s) \text{d}s} (\Sigma_{VP}(0) - I). \]

Hence

\[\bm{x}(t) \sim \mathcal{N}(\bm{e}(0) e^{-\frac{1}{2} \int_0^t \beta (s) \text{d}s}; I + e^{-\int_0^t \beta(s) \text{d}s}(\Sigma_{VP}(0) - I)) \]

Given \(\bm{x}(0)\), we have \(\bm{e}(0) = \bm{x}(0)\) and \(\Sigma_{VP}(0) = 0\), so

\[\bm{x}(t)|\bm{x}(0) \sim \mathcal{N}(\bm{x}(0) e^{-\frac{1}{2} \int_0^t \beta (s) \text{d}s}; I - e^{-\int_0^t \beta(s) \text{d}s}I) \]

Note: the covariance equation above is derived in another reference; the variance solution here is based on that general result.
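The closed-form solutions for \(\bm{e}(t)\) and \(\Sigma_{VP}(t)\) can be checked by integrating the two ODEs numerically (scalar case) with explicit Euler steps. The linear \(\beta(t)\), the initial values and the step count are illustrative assumptions.

```python
import math

BETA_MIN, BETA_MAX = 0.1, 20.0   # illustrative linear schedule


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def beta_integral(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


def integrate_moments(e0, cov0, t_end, n_steps=10000):
    """Explicit Euler for de/dt = -beta(t) e / 2 and dSigma/dt = beta(t) (1 - Sigma)."""
    dt = t_end / n_steps
    e, cov = e0, cov0
    for k in range(n_steps):
        b = beta(k * dt)
        e, cov = e - 0.5 * b * e * dt, cov + b * (1.0 - cov) * dt
    return e, cov


def closed_form(e0, cov0, t):
    """e(t) = e(0) exp(-B/2) and Sigma(t) = 1 + exp(-B) (Sigma(0) - 1), with B = int_0^t beta."""
    B = beta_integral(t)
    return e0 * math.exp(-0.5 * B), 1.0 + math.exp(-B) * (cov0 - 1.0)


print(integrate_moments(2.0, 0.3, 1.0), closed_form(2.0, 0.3, 1.0))
```

Setting `cov0 = 0.0` corresponds to the conditional case \(\Sigma_{VP}(0) = 0\), i.e. the kernel \(\bm{x}(t)|\bm{x}(0)\).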

Extensions

The SMLD and DDPM examples show that different forward perturbation processes can be constructed simply by choosing \(\bm{f}(\bm{x}, t)\) and \(\bm{G}(\bm{x}, t)\). In fact, SMLD and DDPM correspond to two different SDEs: the Variance Exploding (VE) SDE and the Variance Preserving (VP) SDE. The names reflect that SMLD requires \(\sigma_{\max} \rightarrow \infty\), so its variance explodes. By contrast, the derivation above shows that the VP variance stays bounded: it is constant when \(\Sigma_{VP}(0) = I\), and it converges to \(I\) when \(\int_{0}^t \beta (s) \text{d}s \rightarrow +\infty\).

sub-VP SDE

Inspired by the properties of the DDPM VP SDE, the authors design a new forward perturbation process:

\[\text{d}\bm{x} = -\frac{1}{2} \beta(t) \bm{x} \text{d}t + \sqrt{\beta (t) (1 - e^{-2 \int_0^t \beta (s) \text{d} s})} \text{d} \bm{w}. \]

As with DDPM, the expectation of \(\bm{x}(t)\) is

\[\mathbb{E}[\bm{x}(t)] = \mathbb{E}[\bm{x}(0)] e^{-\frac{1}{2} \int_0^t \beta (s) \text{d}s}, \]

and the covariance is

\[\Sigma_{sub-VP}(t) := \text{Cov}[\bm{x}(t)] = I + e^{-2\int_0^t \beta(s) \text{d}s} I + e^{-\int_0^t \beta(s) \text{d}s} (\Sigma_{sub-VP}(0) - 2I). \]

It has two properties:

  1. When \(\Sigma_{VP}(0) = \Sigma_{sub-VP}(0)\), \(\Sigma_{sub-VP}(t) \preceq \Sigma_{VP}(t)\), i.e. it has a smaller variance;
  2. \(\lim_{t \rightarrow \infty} \Sigma_{sub-VP}(t) = I\) when \(\int_0^{+\infty} \beta(s) \text{d} s = +\infty\).

In addition, its conditional distribution is:

\[\bm{x}(t)|\bm{x}(0) \sim \mathcal{N}(\bm{x}(0) e^{-\frac{1}{2} \int_0^t \beta (s) \text{d}s}; (1 - e^{-\int_0^t \beta(s) \text{d}s})^2 I). \]
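For a scalar linear SDE \(\text{d}x = F(t)x\,\text{d}t + L(t)\,\text{d}w\), the variance obeys \(\text{d}\Sigma/\text{d}t = 2F\Sigma + L^2\). For sub-VP, this gives \(\text{d}\Sigma/\text{d}t = -\beta(t)\Sigma + \beta(t)(1 - e^{-2\int_0^t \beta(s)\text{d}s})\). The sketch integrates this ODE with Euler steps, compares it with the closed-form covariance above, and checks property 1 and the conditional variance. The linear \(\beta(t)\) is an illustrative assumption.

```python
import math

BETA_MIN, BETA_MAX = 0.1, 20.0   # illustrative linear schedule


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def beta_integral(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2


def vp_cov(cov0, t):
    return 1.0 + math.exp(-beta_integral(t)) * (cov0 - 1.0)


def sub_vp_cov(cov0, t):
    """Closed form: 1 + exp(-2B) + exp(-B) (Sigma(0) - 2), with B = int_0^t beta."""
    B = beta_integral(t)
    return 1.0 + math.exp(-2.0 * B) + math.exp(-B) * (cov0 - 2.0)


def sub_vp_cov_euler(cov0, t_end, n_steps=20000):
    """Explicit Euler for dSigma/dt = -beta Sigma + beta (1 - exp(-2 int_0^t beta))."""
    dt = t_end / n_steps
    cov = cov0
    for k in range(n_steps):
        t = k * dt
        cov += (-beta(t) * cov + beta(t) * (1.0 - math.exp(-2.0 * beta_integral(t)))) * dt
    return cov


print(sub_vp_cov_euler(0.0, 0.5), sub_vp_cov(0.0, 0.5), vp_cov(0.0, 0.5))
```

With \(\Sigma(0) = 0\), the closed form reduces to \((1 - e^{-\int_0^t \beta})^2\), the conditional variance above. The gap to VP is \(e^{-B} - e^{-2B} \ge 0\).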

Specific sampling algorithms

PC sampling

Corrector

Here the authors construct the Langevin step size directly; note that \(r\) here denotes the signal-to-noise ratio.
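The step-size rule is not spelled out above. The sketch below assumes the form \(\epsilon = 2(r\|\bm{z}\|_2/\|\bm{g}\|_2)^2\), where \(\bm{g}\) is the score at the current point and \(\bm{z}\) is the fresh Gaussian noise. Under that assumption, \(r\) is exactly the ratio of the norm of the "signal" \(\epsilon\bm{g}\) to the norm of the "noise" \(\sqrt{2\epsilon}\bm{z}\) in one Langevin step, which is why \(r\) can be read as a signal-to-noise ratio. The target \(\mathcal{N}(\bm{0}, \sigma^2 I)\), the dimension and all names are arbitrary illustrative choices.

```python
import math
import random


def snr_step_size(g, z, r):
    """eps = 2 * (r * ||z|| / ||g||)^2 for score vector g, Gaussian noise z and target SNR r."""
    g_norm = math.sqrt(sum(v * v for v in g))
    z_norm = math.sqrt(sum(v * v for v in z))
    return 2.0 * (r * z_norm / g_norm) ** 2


def langevin_corrector_step(x, score, r, rng):
    """One Langevin corrector step whose signal-to-noise ratio equals r."""
    g = score(x)
    z = [rng.gauss(0.0, 1.0) for _ in x]
    eps = snr_step_size(g, z, r)
    x_new = [xi + eps * gi + math.sqrt(2.0 * eps) * zi for xi, gi, zi in zip(x, g, z)]
    return x_new, eps, g, z


sigma = 5.0


def gaussian_score(x):
    return [-xi / sigma ** 2 for xi in x]   # score of N(0, sigma^2 I)


rng = random.Random(0)
x = [rng.gauss(0.0, sigma) for _ in range(64)]
x_new, eps, g, z = langevin_corrector_step(x, gaussian_score, 0.16, rng)
signal = eps * math.sqrt(sum(v * v for v in g))                  # ||eps * g||
noise = math.sqrt(2.0 * eps) * math.sqrt(sum(v * v for v in z))  # ||sqrt(2 eps) * z||
```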

Other details

  • Network architecture: the same as in DDPM;
  • Training uses \(N=1000\) noise scales;
  • After sampling, the final \(\bm{x}(0)\) still contains noise that is invisible to the human eye but hurts the FID score, so a final denoising step as in DDPM (Tweedie's formula) is appended at the end;
  • Although training uses \(N=1000\), sampling can use \(N=2000\) or even more steps; the time index then has to be interpolated, e.g.

\[\bm{s}_{\theta}' (\bm{x}, i) \rightarrow \bm{s}_{\theta}' (\bm{x}, i / 2) \quad \text{or} \quad \bm{s}_{\theta}' (\bm{x}, i) \rightarrow \bm{s}_{\theta}' (\bm{x}, \lfloor i / 2 \rfloor);\]

  • The paper also reports the optimal signal-to-noise ratio \(r\) in a figure (not reproduced here).
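For the VE case, Tweedie's formula (used in the final denoising step mentioned above) says that for \(\bm{x} = \bm{x}_0 + \sigma\bm{z}\), \(\mathbb{E}[\bm{x}_0 | \bm{x}] = \bm{x} + \sigma^2 \nabla_{\bm{x}} \log p_{\sigma}(\bm{x})\), so the denoising step costs one extra score evaluation. The sketch checks this against the exact posterior mean for toy 1-D Gaussian data \(\mathcal{N}(m, s^2)\). With a trained model, `exact_score` would be replaced by \(\bm{s}_{\theta}\). The names and values are hypothetical.

```python
def tweedie_denoise(x, sigma, score):
    """Tweedie's formula for x = x0 + sigma * z: E[x0 | x] = x + sigma^2 * score(x, sigma)."""
    return x + sigma ** 2 * score(x, sigma)


m, s = 1.0, 0.5   # toy data distribution N(m, s^2)


def exact_score(x, sigma):
    return -(x - m) / (s ** 2 + sigma ** 2)   # score of the noisy marginal N(m, s^2 + sigma^2)


x_noisy, sigma = 1.8, 0.3
denoised = tweedie_denoise(x_noisy, sigma, exact_score)
posterior_mean = m + s ** 2 / (s ** 2 + sigma ** 2) * (x_noisy - m)   # Gaussian conjugacy
```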

Code

[official]

Copyright notice
This article was written by [Steamed bread and flower rolls]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/172/202206211334467723.html