We can't optimize the 1st part of the final expression, so all that's left is to optimize $\mathbb{E}_{q(x_0)}[-\log p_\theta(x_0)]$.
$\mathbb{E}_{q(x_0)}$ is the expectation under the data distribution; in practice, we estimate it by sampling training examples $x_0$ from the dataset.
We currently have: $p_\theta(x_0)$. But the reverse diffusion process depends on the latents $x_1$ through $x_T$.
We need to pull these into our expression. To do that, we use marginalization.
What's marginalization?
Let $A$ be a random variable representing whether it rains tomorrow.
Let $B$ be a random variable representing some other event. Maybe: Tom catches Jerry today. (Note how the 2 events don't necessarily have to be related.) The probability of it raining tomorrow can then be expressed as P(it rains tomorrow AND Tom catches Jerry today) + P(it rains tomorrow AND Tom does not catch Jerry today).
In this case, both A and B can take on 2 possible values.
$P(A=a)=\sum_b P(A=a, B=b)$
If this was continuous, we would have:
$P(A=a)=\int P(A=a, B=b)\,db$
Side note: in this case, $A$ and $B$ are independent, so the joint factorizes: $P(A=a, B=b) = P(A=a)\,P(B=b)$.
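As a concrete sketch (toy numbers of my own, not from the text), here's discrete marginalization in NumPy:

```python
import numpy as np

# Toy joint distribution P(A, B) over two binary events.
# Rows index A (rain tomorrow: no/yes), columns index B (Tom catches Jerry: no/yes).
joint = np.array([
    [0.42, 0.28],  # P(A=0, B=0), P(A=0, B=1)
    [0.18, 0.12],  # P(A=1, B=0), P(A=1, B=1)
])
assert np.isclose(joint.sum(), 1.0)

# Marginalize out B: P(A=a) = sum_b P(A=a, B=b).
p_a = joint.sum(axis=1)
print(p_a)  # [0.7 0.3]

# This toy joint was built from independent marginals, so it also factorizes:
p_b = joint.sum(axis=0)
assert np.allclose(joint, np.outer(p_a, p_b))
```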
We needed to "pull in" all the latent variables that $x_0$ depends on. To do this, we marginalized the joint $p_\theta(x_{0:T})$ over all the latent variables in the reverse diffusion process: $p_\theta(x_0) = \int p_\theta(x_{0:T})\,dx_{1:T}$.
We scale the mean by $\sqrt{1-\beta_t}$ in order to prevent the values/variance from exploding.
TODO: Be more precise here. Write some tests & verify what is happening.
I originally thought it was because we want to scale the mean down to 0. Indeed, if we put data composed of all 1s into the forward diffusion process, the end result would have a mean of 0 and a stdev of 1.
But that isn't the goal of the scaling. In diffusion models, we assume our input already has a mean of 0. If the forward diffusion process shifted the mean, then our neural net would have to learn to shift it back. That's probably bad, since the neural net uses the same weights for each reverse diffusion step.
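Toward that TODO, here's a minimal numerical check (my own sketch, assuming the linear β schedule from the DDPM paper):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)  # linear schedule, as in DDPM

# Start from unit-variance, zero-mean data and run the forward process:
# x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps
x = rng.standard_normal(100_000)
for beta in betas:
    x = np.sqrt(1 - beta) * x + np.sqrt(beta) * rng.standard_normal(x.shape)
print(x.mean(), x.var())  # stays ~0 and ~1

# Without the sqrt(1 - beta_t) scaling, the variance grows by beta_t each step:
y = rng.standard_normal(100_000)
for beta in betas:
    y = y + np.sqrt(beta) * rng.standard_normal(y.shape)
print(y.var())  # ~1 + sum(betas) ≈ 11
```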
Aside: Training-Inference Compute Asymmetry
Diffusion models allow you to train with much less compute per example (each training step evaluates the network at a single, randomly sampled timestep), while still utilizing a ton of compute at sampling time (sampling runs the network once per diffusion step).
Aside: Functional Form
Both processes have the same functional form when the $\beta_t$ are small.
This is saying "we need a lot of timesteps in the forward/reverse diffusion processes" in order for the true reverse steps to also be Gaussian (the forward steps are Gaussian by construction).
Example: imagine your dataset consists of cats & dogs. If you do a single reverse step to go from noise to "full cat/dog", then your reverse step can't be Gaussian. How do you fit a Gaussian to cats & dogs? You would need a multimodal distribution instead. (A quick numerical illustration follows.)
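Here's a toy 1-D version of that argument (my own construction, not from the paper): a "dataset" with modes at ±1, comparing the true reverse conditional after one big noising step vs. one small one:

```python
import numpy as np

rng = np.random.default_rng(0)

def reverse_conditional_samples(beta, x_t_value, n=2_000_000, window=0.02):
    """Sample x_{t-1} | x_t ≈ x_t_value under one forward step with noise beta,
    starting from a bimodal 'cats & dogs' dataset with modes at -1 and +1."""
    x_prev = rng.choice([-1.0, 1.0], size=n) + 0.05 * rng.standard_normal(n)
    x_t = np.sqrt(1 - beta) * x_prev + np.sqrt(beta) * rng.standard_normal(n)
    mask = np.abs(x_t - x_t_value) < window
    return x_prev[mask]

# One big step (beta near 1): conditioning on x_t ≈ 0 leaves BOTH modes,
# so the true reverse conditional is bimodal -- no Gaussian fits it.
big = reverse_conditional_samples(beta=0.99, x_t_value=0.0)
print(np.histogram(big, bins=np.linspace(-2, 2, 9))[0])  # peaks near ±1, empty middle

# One small step (beta tiny): x_t ≈ 1 pins down the mode, and the
# conditional is a narrow, unimodal, approximately Gaussian bump.
small = reverse_conditional_samples(beta=0.001, x_t_value=1.0)
print(small.mean(), small.std())
```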
Use the fact that $p_\theta$ and $q$ are defined as products + log rules.
A note about $p(x_T)$:
Technically, it should be $p_\theta(x_T)$, but we assume that after enough forward diffusion steps, $p(x_T)$ is identical to the normal distribution (hence the equation earlier in the paper: $p(x_T)=\mathcal{N}(x_T; 0, I)$).
Concretely, to calculate $\log p(x_T)$ we could (sketched in code below):
Take an image from our data set, $x_0$.
Run forward diffusion for $T$ steps to get $x_T$.
Calculate the probability of $x_T$ appearing under the normal distribution, $\mathcal{N}$.
TODO: I'm guessing you would do this by calculating the probability of the normal distribution discretizing to the given image?
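Setting aside the discretization question in the TODO, a continuous-density version of this calculation might look like (my own sketch, assuming a linear β schedule):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)           # assumed linear beta schedule
alpha_bar_T = np.prod(1.0 - betas)           # \bar{alpha}_T, ~4e-5 here

# Jump straight to x_T using the closed form of the forward process:
# x_T = sqrt(alpha_bar_T) * x_0 + sqrt(1 - alpha_bar_T) * eps
x0 = rng.uniform(-1, 1, size=(3, 32, 32))    # stand-in for a normalized image
x_T = np.sqrt(alpha_bar_T) * x0 + np.sqrt(1 - alpha_bar_T) * rng.standard_normal(x0.shape)

# log N(x_T; 0, I): sum of per-pixel standard-normal log densities
log_p_xT = np.sum(-0.5 * (x_T ** 2 + np.log(2 * np.pi)))
print(log_p_xT)
```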
Extract the 1st term out of the summation. This is necessary because when we apply Bayes' Theorem to $q$, we will get a nonsensical result if we also apply it to the 1st term in the summation (for $t=1$, conditioning on $x_0$ would give the degenerate $q(x_0 \mid x_1, x_0)$).
Apply Bayes' rule to $q(x_t \mid x_{t-1})$.
We need to condition the reverse conditional probability, $q$, on $x_0$. Why? $q(x_{t-1} \mid x_t)$ needs to give the probability distribution over $x_{t-1}$ given $x_t$, but this might be extremely difficult if, e.g., $x_t$ has a lot of noise (which it will near the end of the diffusion process). If we know the original image, $x_0$, this process becomes easy. This also makes the reverse conditional probability tractable, i.e., we can compute it.
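Concretely, using the Markov property of the forward process, $q(x_t \mid x_{t-1}) = q(x_t \mid x_{t-1}, x_0)$, Bayes' rule gives:

$$q(x_t \mid x_{t-1}, x_0) = \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}$$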
Log rules.
Expand the 2nd summation, apply log rules to get the log of a cumulative product, cancel terms.
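Spelling that out, the Bayes' rule substitution leaves a sum of log-ratios that telescopes:

$$\sum_{t=2}^{T} \log \frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)} = \log \prod_{t=2}^{T} \frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)} = \log \frac{q(x_T \mid x_0)}{q(x_1 \mid x_0)}$$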
We go from $\beta_t$ to $\sqrt{\beta_t}$ because $\beta_t I$ is a covariance matrix, and the reparameterization trick scales the noise by the standard deviation. Specifically, it's a diagonal covariance matrix, so it only has entries of the form $\mathrm{Cov}(X,X)=\mathrm{Var}(X)=\sigma^2$, and $\sqrt{\sigma^2}=\sigma$ is the standard deviation.
Substitute $\beta_t$ with $\alpha_t$, where $\alpha_t = 1-\beta_t$.
Substitute using $x_{t-1}=\sqrt{\alpha_{t-1}}\,x_{t-2}+\sqrt{1-\alpha_{t-1}}\,\epsilon$.
Algebra
$\epsilon$ is sampled from the normal distribution, $\epsilon \sim \mathcal{N}(0, I)$. Thus, the last 2 terms are Gaussians with mean 0 and standard deviations $\sqrt{\alpha_t(1-\alpha_{t-1})}$ and $\sqrt{1-\alpha_t}$. $\mathrm{Var}(X+Y)=\mathrm{Var}(X)+\mathrm{Var}(Y)$ if $X$ and $Y$ are independent. Thus, the sum of those two terms is a Gaussian with mean 0 and variance $\alpha_t(1-\alpha_{t-1})+(1-\alpha_t)=1-\alpha_t\alpha_{t-1}$.
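A quick numerical check of that variance merge (toy $\alpha$ values of my choosing):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha_t, alpha_tm1 = 0.9, 0.8  # arbitrary example values
n = 1_000_000

# Sum of the two independent zero-mean Gaussian noise terms:
z = (np.sqrt(alpha_t * (1 - alpha_tm1)) * rng.standard_normal(n)
     + np.sqrt(1 - alpha_t) * rng.standard_normal(n))

print(z.var())                  # empirical variance, ~0.28
print(1 - alpha_t * alpha_tm1)  # predicted: 1 - alpha_t * alpha_{t-1} = 0.28
```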
One last note: $x$ is actually $x_{t-1}$, since that is the output of $q(x_{t-1} \mid x_t, x_0)=\mathcal{N}(x_{t-1};\tilde{\mu}(x_t,x_0),\tilde{\beta}_t I)$.
Now we can derive the mean and variance of $q(x_{t-1} \mid x_t, x_0)$:
I could not figure out the algebra for this step. But, I did verify the results numerically.
See these tests.
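In the same spirit, here's a self-contained numerical check (my own sketch). It uses the fact that $(x_{t-1}, x_t)$ are jointly Gaussian given $x_0$, so the conditional mean is linear in $x_t$ and the conditional variance is the regression residual variance:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2_000_000
betas = np.array([0.1, 0.2, 0.3])          # toy schedule; check t = 3
alphas = 1 - betas
alpha_bars = np.cumprod(alphas)
x0 = 0.7                                    # fix a scalar "image"

# Sample (x_{t-1}, x_t) from the forward process given x0 (closed form for x_{t-1}):
x_prev = np.sqrt(alpha_bars[1]) * x0 + np.sqrt(1 - alpha_bars[1]) * rng.standard_normal(n)
x_t = np.sqrt(alphas[2]) * x_prev + np.sqrt(betas[2]) * rng.standard_normal(n)

# Empirical conditional: regress x_{t-1} on x_t (valid because they're jointly Gaussian)
slope = np.cov(x_prev, x_t)[0, 1] / x_t.var()
resid_var = x_prev.var() - slope ** 2 * x_t.var()

# DDPM closed form: beta_tilde and the x_t-coefficient of mu_tilde
beta_tilde = (1 - alpha_bars[1]) / (1 - alpha_bars[2]) * betas[2]
mu_coeff_xt = np.sqrt(alphas[2]) * (1 - alpha_bars[1]) / (1 - alpha_bars[2])

print(resid_var, beta_tilde)    # should match
print(slope, mu_coeff_xt)       # should match
```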
Closed Form KL-Divergence
We want to calculate the KL-divergence efficiently.
Normally, we would need to use Monte Carlo estimation.
But, since $L_{t-1}$ is comparing 2 Gaussians, we can use the closed form of the KL divergence:
$$D_{KL}(p, q) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$
The authors ignore the 1st & last terms since they don't learn the variance for the reverse diffusion process.
Question: why bother including $\sigma_t$ in the denominator given that we're not learning it? (My guess: it makes the objective resemble denoised score matching.)
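To see the closed form agree with the Monte Carlo route it replaces, here's a quick comparison (toy parameters of my choosing):

```python
import numpy as np

rng = np.random.default_rng(0)
mu1, sigma1 = 0.3, 0.8   # p
mu2, sigma2 = 0.0, 1.0   # q

# Closed form for KL(p || q) between two univariate Gaussians
kl_closed = (np.log(sigma2 / sigma1)
             + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
             - 0.5)

# Monte Carlo estimate: E_p[log p(x) - log q(x)]
x = mu1 + sigma1 * rng.standard_normal(1_000_000)
log_p = -0.5 * ((x - mu1) / sigma1) ** 2 - np.log(sigma1) - 0.5 * np.log(2 * np.pi)
log_q = -0.5 * ((x - mu2) / sigma2) ** 2 - np.log(sigma2) - 0.5 * np.log(2 * np.pi)
kl_mc = (log_p - log_q).mean()

print(kl_closed, kl_mc)  # should agree to ~3 decimal places
```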
Extensions
DDIM
DDIM allows deterministic sampling without modifying the training process.
Let's analyze the OpenAI code for DDIM sampling:
```python
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_{t-1} from the model using DDIM.

    Same usage as p_sample().
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    # ===MARKER 1===
    sigma = (
        eta
        * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
        * th.sqrt(1 - alpha_bar / alpha_bar_prev)
    )
    # Equation 12.
    noise = th.randn_like(x)
    # ===MARKER 2===
    mean_pred = (
        out["pred_xstart"] * th.sqrt(alpha_bar_prev)
        + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
    )
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    sample = mean_pred + nonzero_mask * sigma * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
```
Where $\epsilon_\theta$ (`eps` in the code) is the noise predicted by the network and $\epsilon$ (`noise`) is random noise we generate.
Note: when `eta=0`, `sigma` is 0 and this simplifies to:

$$x_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\,x_0+\sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon_\theta$$

which is a deterministic sampling process! In essence, instead of sampling from the predicted posterior mean/variance, we just use the reparameterization trick and stop one step earlier.
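As a sanity check on that simplification (a standalone sketch with made-up tensors, mirroring the variable names above):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha_bar, alpha_bar_prev = 0.5, 0.6        # made-up schedule values
pred_xstart = rng.standard_normal((4, 4))   # stand-in for the model's x_0 prediction
eps = rng.standard_normal((4, 4))           # stand-in for the re-derived epsilon

eta = 0.0
sigma = (eta * np.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
         * np.sqrt(1 - alpha_bar / alpha_bar_prev))
assert sigma == 0.0  # eta = 0 kills the stochastic term entirely

mean_pred = (pred_xstart * np.sqrt(alpha_bar_prev)
             + np.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps)

# With sigma = 0, the sample is exactly the deterministic DDIM update:
expected = np.sqrt(alpha_bar_prev) * pred_xstart + np.sqrt(1 - alpha_bar_prev) * eps
assert np.allclose(mean_pred, expected)
```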