Implementing Denoising Diffusion Models from Scratch using PyTorch

1 minute read

Published:

Denoising Diffusion Models are a powerful class of generative models that transform data distributions into noise distributions and then learn to reverse this process to generate new data. They build upon concepts from Hierarchical Variational Autoencoders (HVAEs) and utilize Markov chains for their forward and reverse diffusion processes. This article provides a detailed guide on implementing a denoising diffusion model from scratch, covering the theoretical background, training methodology, and sampling process.

From HVAE to Diffusion Model

Hierarchical Variational Autoencoders (HVAEs) generalize VAEs by introducing multiple hierarchies over latent variables. The Markovian assumption is applied to simplify the model, making it more manageable.

Diffusion Processes in Generative Models

A diffusion process in generative models is a Markov chain that progressively transforms a data distribution into a known noise distribution. The model then learns to reverse this process to generate new data.

Forward Diffusion Process

The forward diffusion process is a fixed Markov chain that gradually adds Gaussian noise to the data according to a variance schedule. The variances can be learned or set as hyperparameters.

Reverse Diffusion Process

The reverse process is defined as a Markov chain with learned Gaussian transitions starting at a standard normal distribution.

Training

Training involves developing an Evidence Lower Bound (ELBO) to estimate the marginal distribution. Efficient training can be achieved through stochastic gradient descent and variance reduction techniques.

Training Algorithm

  1. Sample from the data distribution.
  2. Randomly select a time step.
  3. Sample noise from a standard normal distribution.
  4. Take a gradient descent step on the objective.

Sampling

The sampling process involves generating samples from a probabilistic model by reversing the diffusion process.

  1. Initialize from a standard normal distribution.
  2. Loop over time steps, updating the sample at each step.
  3. Return the final generated sample.

To read the entire article and learn how to implement it using PyTorch, click here.