DENOISING DIFFUSION IMPLICIT MODELS (DDIM)
We know from DDPM that its diffusion process (the forward, or noising, process) is defined as a Markov process, and its denoising process (the reverse process) is also a Markov process. Because of this Markov assumption, each reconstruction step depends on the state at the previous step, so inference requires a large number of steps.
DDPM models the reverse distribution with the help of the Markov assumption. Applying Bayes' rule, the unknown term \(q(x_t|x_{t-1},x_0)\) reduces to the known term \(q(x_t|x_{t-1})\). The resulting \(q(x_{t-1}|x_t,x_0)\) turns out to be a Gaussian \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\).
Starting from this conclusion of DDPM, we can simply assume that \(q(x_{t-1}|x_t,x_0)\) is Gaussian and try to solve for it without using the Markov assumption.
In DDPM's \(q(x_{t-1}|x_t,x_0)=\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\), the mean is a function of \(x_t\) and \(x_0\) (a linear combination of the two), and the variance is a function of \(t\) only.
We can therefore design \(q(x_{t-1}|x_t,x_0)\) as the following distribution:
\[q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1};\ a x_0 + b x_t,\ \sigma_t^2 I)\]
Determining \(q(x_{t-1}|x_t,x_0)\) then only requires solving for the three undetermined coefficients \(a\), \(b\) and \(\sigma_t\).
Reparameterizing \(q(x_{t-1}|x_t,x_0)\):
\[x_{t-1} = a x_0 + b x_t + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]
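Assume that the noising parameters applied to the input images during training are identical to those of DDPM. Then \(q(x_t|x_{0}) := \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_{0},(1-\bar{\alpha}_t)I)\), i.e.
\[x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, I)\]
Substituting \(x_t\) gives:
\[\begin{aligned} x_{t-1} &= a x_0 + b\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t\right) + \sigma_t\epsilon \\ &= \left(a + b\sqrt{\bar{\alpha}_t}\right)x_0 + \sqrt{b^2(1-\bar{\alpha}_t) + \sigma_t^2}\;\bar{\epsilon}, \quad \bar{\epsilon}\sim\mathcal{N}(0,I) \end{aligned}\]
The last step merges the two independent Gaussian noise terms into one by adding their variances.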
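On the other hand, the same DDPM marginal applied at step \(t-1\) gives:
\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,x_0 + \sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon_{t-1}, \quad \epsilon_{t-1}\sim\mathcal{N}(0,I)\]
Matching the coefficients yields a system of equations:
\[\begin{cases} a + b\sqrt{\bar{\alpha}_t} = \sqrt{\bar{\alpha}_{t-1}} \\ b^2(1-\bar{\alpha}_t) + \sigma_t^2 = 1-\bar{\alpha}_{t-1} \end{cases}\]
There are three unknowns but only two equations, so we can express \(a\) and \(b\) in terms of \(\sigma_t\):
\[b = \sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}}, \qquad a = \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_t}\sqrt{\frac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}}\]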
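The two matching equations can be checked numerically. The sketch below solves for \(a, b\) from a chosen \(\sigma_t\). It then draws \(x_{t-1} = a x_0 + b x_t + \sigma_t\epsilon\) by Monte Carlo and compares the sample mean and variance with the DDPM marginal at \(t-1\). The values of `abar_t`, `abar_prev`, `sigma_t` and `x0` are arbitrary toy numbers chosen for illustration. They are not taken from any real schedule. The only requirement is \(\sigma_t^2 \le 1-\bar{\alpha}_{t-1}\), so that \(b\) is real.

```python
import numpy as np

def ddim_coeffs(abar_t, abar_prev, sigma_t):
    """Solve a + b*sqrt(abar_t) = sqrt(abar_prev) and
    b**2 * (1 - abar_t) + sigma_t**2 = 1 - abar_prev for a and b."""
    b = np.sqrt((1.0 - abar_prev - sigma_t**2) / (1.0 - abar_t))
    a = np.sqrt(abar_prev) - b * np.sqrt(abar_t)
    return a, b

# Toy values (assumptions for illustration only); abar_t < abar_prev as in any schedule.
abar_t, abar_prev, sigma_t = 0.5, 0.6, 0.1
a, b = ddim_coeffs(abar_t, abar_prev, sigma_t)

# Monte Carlo: x_t from q(x_t|x_0), then x_{t-1} = a*x0 + b*x_t + sigma_t*eps.
rng = np.random.default_rng(0)
x0 = 2.0
n = 200_000
x_t = np.sqrt(abar_t) * x0 + np.sqrt(1 - abar_t) * rng.standard_normal(n)
x_prev = a * x0 + b * x_t + sigma_t * rng.standard_normal(n)

print(x_prev.mean(), np.sqrt(abar_prev) * x0)  # sample mean vs. DDPM marginal mean
print(x_prev.var(), 1 - abar_prev)             # sample variance vs. DDPM marginal variance
```

Any \(\sigma_t\) in the valid range reproduces the same marginal. This freedom is what DDIM exploits later.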
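Substituting \(a, b\) into \(q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1}; a x_0 + b x_t,\sigma_t^2 I)\) gives:
\[q(x_{t-1}|x_t,x_0) = \mathcal{N}\left(x_{t-1};\ \sqrt{\bar{\alpha}_{t-1}}\,x_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\frac{x_t-\sqrt{\bar{\alpha}_t}\,x_0}{\sqrt{1-\bar{\alpha}_t}},\ \sigma_t^2 I\right)\]
Moreover, from \(x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t\) we have:
\[x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_t}{\sqrt{\bar{\alpha}_t}}\]
Substituting \(x_0\) (and noting that \((x_t-\sqrt{\bar{\alpha}_t}\,x_0)/\sqrt{1-\bar{\alpha}_t} = \epsilon_t\)) gives:
\[q(x_{t-1}|x_t,x_0) = \mathcal{N}\left(x_{t-1};\ \sqrt{\bar{\alpha}_{t-1}}\,\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_t}{\sqrt{\bar{\alpha}_t}} + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_t,\ \sigma_t^2 I\right)\]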
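As a sanity check, the two forms of the mean (\(a x_0 + b x_t\) and the \(\epsilon_t\) form obtained after eliminating \(x_0\)) should agree exactly. The sketch below uses the same toy constants as an assumption and random vectors in place of images.

```python
import numpy as np

abar_t, abar_prev, sigma_t = 0.5, 0.6, 0.1  # assumed toy values
rng = np.random.default_rng(1)
x0 = rng.standard_normal(5)
eps = rng.standard_normal(5)
x_t = np.sqrt(abar_t) * x0 + np.sqrt(1 - abar_t) * eps

# Form 1: a*x0 + b*x_t with a, b from coefficient matching.
b = np.sqrt((1 - abar_prev - sigma_t**2) / (1 - abar_t))
a = np.sqrt(abar_prev) - b * np.sqrt(abar_t)
mean_ab = a * x0 + b * x_t

# Form 2: recover x0 from x_t and eps, then use the eps form of the mean.
x0_from_eps = (x_t - np.sqrt(1 - abar_t) * eps) / np.sqrt(abar_t)
mean_eps = np.sqrt(abar_prev) * x0_from_eps + np.sqrt(1 - abar_prev - sigma_t**2) * eps

print(np.max(np.abs(mean_ab - mean_eps)))  # floating-point round-off only
```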
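Guided by this form of the \(x_{t-1}\) distribution, we model the sampling distribution as a Gaussian:
\[p_\theta(x_{t-1}|x_t) = \mathcal{N}\left(x_{t-1};\ \mu_\theta(x_t,t),\ \Sigma_\theta(x_t,t)\right)\]
whose mean and variance take the same form:
\[\mu_\theta(x_t,t) = \sqrt{\bar{\alpha}_{t-1}}\,\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}} + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t,t), \qquad \Sigma_\theta(x_t,t) = \sigma_t^2 I\]
where \(\epsilon_\theta(x_t,t)\) is the noise predicted by the network.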
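The optimization objective now only requires the two distributions \(q(x_{t-1}|x_t,x_0)\) and \(p_\theta(x_{t-1}|x_t)\) to be as close as possible. Both are Gaussians with the same covariance \(\sigma_t^2 I\), so their KL divergence reduces to the squared distance between the means:
\[D_{KL}\left(q(x_{t-1}|x_t,x_0)\,\|\,p_\theta(x_{t-1}|x_t)\right) = \frac{1}{2\sigma_t^2}\left\|\mu_q - \mu_\theta\right\|^2 = \frac{1}{2\sigma_t^2}\left(\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} - \sqrt{\frac{\bar{\alpha}_{t-1}(1-\bar{\alpha}_t)}{\bar{\alpha}_t}}\right)^2\left\|\epsilon_t - \epsilon_\theta(x_t,t)\right\|^2\]
Apart from the time-dependent weight in front, this is the same noise-prediction objective \(\|\epsilon_t - \epsilon_\theta(x_t,t)\|^2\) that DDPM optimizes; DDPM's simplified loss drops that weight. We can therefore reuse a model trained with DDPM directly.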
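The reduction of the KL term to a weighted noise-prediction loss can be verified numerically. The sketch computes the KL between the two Gaussians directly from their means and compares it with the weight times \(\|\epsilon_t-\epsilon_\theta\|^2\). Here `eps_theta` is just a random vector standing in for a network output, and the constants are assumed toy values with \(\sigma_t > 0\), so the KL is finite.

```python
import numpy as np

def gaussian_kl_same_iso_var(mu_q, mu_p, sigma):
    """KL(N(mu_q, sigma^2 I) || N(mu_p, sigma^2 I)) = ||mu_q - mu_p||^2 / (2 sigma^2)."""
    return np.sum((mu_q - mu_p) ** 2) / (2.0 * sigma**2)

def mean_from_eps(x_t, eps, abar_t, abar_prev, sigma):
    """Mean of q (with the true eps) or of p_theta (with eps_theta)."""
    x0_hat = (x_t - np.sqrt(1 - abar_t) * eps) / np.sqrt(abar_t)
    return np.sqrt(abar_prev) * x0_hat + np.sqrt(1 - abar_prev - sigma**2) * eps

abar_t, abar_prev, sigma = 0.5, 0.6, 0.1  # assumed toy values
rng = np.random.default_rng(2)
x0, eps, eps_theta = rng.standard_normal((3, 8))  # eps_theta: stand-in for a model output
x_t = np.sqrt(abar_t) * x0 + np.sqrt(1 - abar_t) * eps

kl = gaussian_kl_same_iso_var(mean_from_eps(x_t, eps, abar_t, abar_prev, sigma),
                              mean_from_eps(x_t, eps_theta, abar_t, abar_prev, sigma), sigma)

# Same value written as (time-dependent weight) * (DDPM noise-prediction loss).
w = (np.sqrt(1 - abar_prev - sigma**2)
     - np.sqrt(abar_prev * (1 - abar_t) / abar_t)) ** 2 / (2 * sigma**2)
print(kl, w * np.sum((eps - eps_theta) ** 2))
```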
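Sampling from \(p_\theta\) then proceeds as:
\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\left(\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}}\right) + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t,t) + \sigma_t z, \quad z\sim\mathcal{N}(0,I)\]
The three terms are the (rescaled) predicted \(x_0\), the direction pointing to \(x_t\), and random noise.
Let \(\sigma_t=\eta \sqrt{\dfrac{(1-{\alpha}_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\).
When \(\eta=1\), \(\sigma_t^2\) equals the DDPM posterior variance. The forward process then becomes Markovian, and the sampling process becomes DDPM.
When \(\eta=0\), the sampling process is deterministic. The model in this case is called an implicit probabilistic model.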
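A single \(p_\theta\) step can be sketched as a small function. Here \(\sigma_t\) is written in the equivalent form \(\eta\sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\left(1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}\right)}\). For adjacent steps \(\bar{\alpha}_t/\bar{\alpha}_{t-1}=\alpha_t\), so it matches the formula above. The usage lines check both limiting cases stated above: with \(\eta=1\), \(\sigma_t^2\) equals the DDPM posterior variance \(\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\); with \(\eta=0\), the step ignores the random draw. The linear schedule mirrors the sampling code. The function names are illustrative.

```python
import numpy as np

def ddim_sigma(abar_t, abar_prev, eta):
    """sigma_t = eta * sqrt((1 - abar_prev) / (1 - abar_t) * (1 - abar_t / abar_prev))."""
    return eta * np.sqrt((1 - abar_prev) / (1 - abar_t) * (1 - abar_t / abar_prev))

def ddim_step(x_t, eps_theta, abar_t, abar_prev, eta, rng):
    """x_{t-1} = sqrt(abar_prev)*x0_hat + sqrt(1 - abar_prev - sigma^2)*eps_theta + sigma*z."""
    sigma = ddim_sigma(abar_t, abar_prev, eta)
    x0_hat = (x_t - np.sqrt(1 - abar_t) * eps_theta) / np.sqrt(abar_t)  # predicted x0
    direction = np.sqrt(1 - abar_prev - sigma**2) * eps_theta          # direction pointing to x_t
    return np.sqrt(abar_prev) * x0_hat + direction + sigma * rng.standard_normal(x_t.shape)

betas = np.linspace(1e-4, 0.02, 1000)  # linear schedule, as in the sampling code
abar = np.cumprod(1 - betas)
t = 500

# eta = 1 on adjacent steps recovers the DDPM posterior variance.
sigma_ddpm = ddim_sigma(abar[t], abar[t - 1], eta=1.0)
beta_tilde = (1 - abar[t - 1]) / (1 - abar[t]) * betas[t]
print(sigma_ddpm**2, beta_tilde)

# eta = 0 is deterministic: different random generators give the same result.
x = np.random.default_rng(0).standard_normal(4)
e = np.random.default_rng(5).standard_normal(4)
out1 = ddim_step(x, e, abar[t], abar[t - 1], 0.0, np.random.default_rng(1))
out2 = ddim_step(x, e, abar[t], abar[t - 1], 0.0, np.random.default_rng(2))
```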
How DDIM accelerates sampling:
In DDPM, because of the Markov chain, \(t\) and \(t-1\) are adjacent timesteps: e.g. if \(t=100\), then \(t-1=99\).
In DDIM, \(t\) and \(t-1\) only indicate an order. For example, when \(t=100\), "\(t-1\)" can be 90, 80 or 70, as long as it is smaller than \(t\). This works because the derivation above only used the marginals \(q(x_t|x_0)\), never the one-step transition \(q(x_t|x_{t-1})\), so it holds for any pair of timesteps.
The sampling sequence constructed this way, \(\tau=[\tau_S,\tau_{S-1},\cdots,\tau_{1}]\), is a subsequence of \([T,T-1,\cdots,1]\) with far fewer elements (\(S \ll T\)).
For example, if the original sequence is \([100,99,98,\cdots,1]\), the sampling sequence can be \(\tau=[100,90,80,\cdots,1]\).
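Building \(\tau\) and the \((\tau_i, \tau_{i-1})\) pairs that the sampler iterates over can be sketched as below. The sketch uses the 1-based numbering of the example above; the sampling code itself uses 0-based indices \([980, 960, \ldots, 0]\). Always appending a final \(t=1\) is an assumption made to match the "..., 1" in the example. `make_tau` and `step_pairs` are hypothetical helper names.

```python
def make_tau(T, stride):
    """Descending sampling sub-sequence, e.g. T=100, stride=10 -> [100, 90, ..., 10, 1]."""
    tau = list(range(T, 0, -stride))
    if tau[-1] != 1:
        tau.append(1)  # assumption: always finish at t = 1, as in the example
    return tau

def step_pairs(tau):
    """Consecutive (tau_i, tau_{i-1}) pairs; each pair is one denoising step."""
    return list(zip(tau[:-1], tau[1:]))

tau = make_tau(100, 10)
print(tau)                   # [100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 1]
print(len(step_pairs(tau)))  # 10 denoising steps, versus 99 for [100, 99, ..., 1]
```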
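The DDIM sampling equation is:
\[x_{\tau_{i-1}} = \sqrt{\bar{\alpha}_{\tau_{i-1}}}\left(\frac{x_{\tau_i} - \sqrt{1-\bar{\alpha}_{\tau_i}}\,\epsilon_\theta(x_{\tau_i},\tau_i)}{\sqrt{\bar{\alpha}_{\tau_i}}}\right) + \sqrt{1-\bar{\alpha}_{\tau_{i-1}}-\sigma_{\tau_i}^2}\,\epsilon_\theta(x_{\tau_i},\tau_i) + \sigma_{\tau_i} z, \quad z\sim\mathcal{N}(0,I)\]
For a jump that skips timesteps, the single-step \(\alpha_t\) in \(\sigma_t\) is replaced by \(\bar{\alpha}_{\tau_i}/\bar{\alpha}_{\tau_{i-1}}\):
\[\sigma_{\tau_i} = \eta\sqrt{\frac{1-\bar{\alpha}_{\tau_{i-1}}}{1-\bar{\alpha}_{\tau_i}}\left(1-\frac{\bar{\alpha}_{\tau_i}}{\bar{\alpha}_{\tau_{i-1}}}\right)}\]
When \(\eta=0\), the DDIM sampling equation becomes:
\[x_{\tau_{i-1}} = \frac{\sqrt{\bar{\alpha}_{\tau_{i-1}}}}{\sqrt{\bar{\alpha}_{\tau_i}}}\,x_{\tau_i} + \left(\sqrt{1-\bar{\alpha}_{\tau_{i-1}}} - \frac{\sqrt{\bar{\alpha}_{\tau_{i-1}}}}{\sqrt{\bar{\alpha}_{\tau_i}}}\sqrt{1-\bar{\alpha}_{\tau_i}}\right)\epsilon_\theta(x_{\tau_i},\tau_i)\]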
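One way to see why the deterministic jump is consistent is to check a single step with an ideal prediction. If \(\epsilon_\theta\) returned exactly the noise used to build \(x_{\tau_i}\), the \(\eta=0\) update should land exactly on \(\sqrt{\bar{\alpha}_{\tau_{i-1}}}x_0+\sqrt{1-\bar{\alpha}_{\tau_{i-1}}}\epsilon\), even across a stride of 20 steps. The sketch below assumes that ideal prediction (the true `eps` is passed in place of a model) and uses random vectors as data.

```python
import numpy as np

def ddim_step_eta0(x_i, eps_theta, abar_i, abar_prev):
    """Deterministic DDIM update from tau_i to tau_{i-1}."""
    ratio = np.sqrt(abar_prev) / np.sqrt(abar_i)
    return ratio * x_i + (np.sqrt(1 - abar_prev) - ratio * np.sqrt(1 - abar_i)) * eps_theta

betas = np.linspace(1e-4, 0.02, 1000)
abar = np.cumprod(1 - betas)
rng = np.random.default_rng(3)
x0, eps = rng.standard_normal((2, 6))

tau_i, tau_prev = 980, 960  # one jump of 20 steps, matching the stride of the sampling code
x_i = np.sqrt(abar[tau_i]) * x0 + np.sqrt(1 - abar[tau_i]) * eps
x_prev = ddim_step_eta0(x_i, eps, abar[tau_i], abar[tau_prev])  # ideal eps_theta = eps (assumption)
target = np.sqrt(abar[tau_prev]) * x0 + np.sqrt(1 - abar[tau_prev]) * eps
print(np.max(np.abs(x_prev - target)))  # round-off only
```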
Code implementation
Training is the same as for DDPM; the training code is in the previous article. The sampling code is as follows:
```python
import torch

# Unet and show_img_batch are defined in the previous article (DDPM training code).
device = 'cuda'
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth', map_location=device))
model.eval()

image_size = 96
epochs = 500
batch_size = 128
T = 1000
betas = torch.linspace(0.0001, 0.02, T).to(device)  # torch.Size([1000])
# sample one timestep every 20 steps
tau_index = list(reversed(range(0, T, 20)))  # [980, 960, ..., 20, 0]
eta = 0.003

# noise schedule, identical to the one used in training
alphas = 1 - betas  # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, dim=0)  # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

def get_val_by_index(val, t, x_shape):
    # pick val[t] for each batch element and reshape to [batch_t, 1, 1, 1] for broadcasting
    batch_t = t.shape[0]
    out = val.gather(-1, t)
    return out.reshape(batch_t, *((1,) * (len(x_shape) - 1)))

def p_sample_ddim(model):
    def step_denoise(model, x_tau_i, tau_i, tau_i_1):
        sqrt_alphas_bar_tau_i = get_val_by_index(sqrt_alphas_cumprod, tau_i, x_tau_i.shape)
        sqrt_alphas_bar_tau_i_1 = get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)
        sqrt_1_minus_alphas_bar_tau_i = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)
        sqrt_1_minus_alphas_bar_tau_i_1 = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i_1, x_tau_i.shape)
        denoise = model(x_tau_i, tau_i)  # predicted noise eps_theta(x_tau_i, tau_i)
        if eta == 0:
            # deterministic DDIM update
            x_tau_i_1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * x_tau_i \
                + (sqrt_1_minus_alphas_bar_tau_i_1 - sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * sqrt_1_minus_alphas_bar_tau_i) \
                * denoise
            return x_tau_i_1
        alphas_bar_tau_i = get_val_by_index(alphas_cumprod, tau_i, x_tau_i.shape)
        alphas_bar_tau_i_1 = get_val_by_index(alphas_cumprod, tau_i_1, x_tau_i.shape)
        # sigma for a strided step: alpha_t is replaced by alphas_bar_tau_i / alphas_bar_tau_i_1
        sigma = eta * torch.sqrt((1 - alphas_bar_tau_i_1) / (1 - alphas_bar_tau_i)
                                 * (1 - alphas_bar_tau_i / alphas_bar_tau_i_1))
        noise_z = torch.randn_like(x_tau_i)
        # The whole equation consists of three parts
        c1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * (x_tau_i - sqrt_1_minus_alphas_bar_tau_i * denoise)  # sqrt(alphas_bar_tau_i_1) * predicted x0
        c2 = torch.sqrt(1 - alphas_bar_tau_i_1 - sigma ** 2) * denoise  # direction pointing to x_t
        c3 = sigma * noise_z  # random noise
        return c1 + c2 + c3

    img_pred = torch.randn((4, 3, image_size, image_size), device=device)
    for k in range(len(tau_index) - 1):
        # tau_index is in descending order, so tau_i = tau_index[k] and tau_i_1 = tau_index[k+1]; not the other way around
        tau_i = torch.tensor([tau_index[k]], device=device, dtype=torch.long)
        tau_i_1 = torch.tensor([tau_index[k + 1]], device=device, dtype=torch.long)
        img_pred = step_denoise(model, img_pred, tau_i, tau_i_1)
        torch.cuda.empty_cache()
    return img_pred

with torch.no_grad():
    img = p_sample_ddim(model)
    img = torch.clamp(img, -1.0, 1.0)
    show_img_batch(img.detach().cpu())
```
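Without a trained checkpoint, the loop structure can still be exercised on a toy problem. This is a NumPy sketch that assumes a hypothetical data set consisting of a single point `x_star`. For such data the exact noise prediction is known in closed form (`oracle_eps`, an illustrative stand-in for `model`). DDIM with the same 20-step stride should therefore end close to `x_star` for both \(\eta=0\) and \(\eta=1\).

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
abar = np.cumprod(1 - betas)
tau_index = list(reversed(range(0, T, 20)))  # [980, 960, ..., 20, 0]

x_star = np.array([0.5, -1.0, 2.0])  # hypothetical one-point data set

def oracle_eps(x_t, t):
    # Exact noise for point-mass data: invert x_t = sqrt(abar)*x_star + sqrt(1-abar)*eps.
    return (x_t - np.sqrt(abar[t]) * x_star) / np.sqrt(1 - abar[t])

def ddim_sample(eta, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x_star.shape)
    for t, t_prev in zip(tau_index[:-1], tau_index[1:]):
        eps = oracle_eps(x, t)
        sigma = eta * np.sqrt((1 - abar[t_prev]) / (1 - abar[t]) * (1 - abar[t] / abar[t_prev]))
        x0_hat = (x - np.sqrt(1 - abar[t]) * eps) / np.sqrt(abar[t])
        x = (np.sqrt(abar[t_prev]) * x0_hat
             + np.sqrt(1 - abar[t_prev] - sigma**2) * eps
             + sigma * rng.standard_normal(x.shape))
    return x

print(ddim_sample(eta=0.0), ddim_sample(eta=1.0, seed=1), x_star)
```

The result is only approximately `x_star`, not exact. The last step stops at index 0, where \(\bar{\alpha}_0 = 1-\beta_1 \approx 0.9999\) rather than 1, so a small amount of noise remains.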
References
DDIM paper: /pdf/2010.02502
Official DDIM implementation: /ermongroup/ddim