Multistep Distillation of Diffusion Models via Moment Matching
Paper | Unofficial code
A slightly older but still interesting Google DeepMind paper on accelerating diffusion models: the authors propose distilling a many-step diffusion sampler into a few-step stochastic sampler using moment matching.
The motivation is straightforward: diffusion models achieve strong generation quality, but sampling is expensive, often requiring hundreds or thousands of neural function evaluations (NFEs). The paper shows that sampling can be reduced to roughly 1–8 steps, and in some settings the distilled student even outperforms the original many-step teacher in FID.
🔬 Method
The key idea is that for few-step distillation, it is not enough to simply predict the mean of the next denoising step. When taking large jumps in time, the student has to match the distribution induced by the teacher along the sampling trajectory.
The central condition is: if the student generates x̃ after a coarse reverse step from z_t, then after moving to a cleaner state z_s, the generated-side denoising moment should match the data-side denoising moment:
E_g[x̃ | z_s] = E_q[x | z_s]
In practice:
🔅 start from a noisy data state z_t
🔅 let the student take a coarse reverse step and generate x̃
🔅 sample z_s conditionally on both z_t and x̃
🔅 compare the generated-side moment with the teacher denoiser g_θ(z_s, s)
🔅 train the student so this moment gap disappears
A crucial detail: in the multistep setting, the transition to z_s should be conditional, i.e. q(z_s | x̃, z_t), rather than just adding fresh noise to x̃. Otherwise, diversity drops significantly, as the authors show in the paper.
(The attached image was taken from the YSDA lecture notes.)
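To make the conditional transition q(z_s | x̃, z_t) concrete, here is a minimal sketch for a standard Gaussian diffusion with z_t = α_t·x + σ_t·noise. The α/σ notation, the function names, and the per-coordinate list representation are assumptions for illustration and are not taken from the paper's code; the formula is the usual Gaussian posterior of the forward process.

```python
import math
import random


def posterior_params(alpha_s, sigma_s, alpha_t, sigma_t):
    """Coefficients of q(z_s | x, z_t) for s < t (Gaussian forward process).

    Returns (coef_zt, coef_x, var) such that
    z_s ~ Normal(coef_zt * z_t + coef_x * x, var).
    """
    alpha_ts = alpha_t / alpha_s                      # alpha_{t|s}
    var_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2    # sigma^2_{t|s}
    coef_zt = alpha_ts * sigma_s**2 / sigma_t**2
    coef_x = alpha_s * var_ts / sigma_t**2
    var = var_ts * sigma_s**2 / sigma_t**2
    return coef_zt, coef_x, var


def sample_z_s(x_tilde, z_t, alpha_s, sigma_s, alpha_t, sigma_t, rng=random):
    """Draw z_s conditionally on BOTH the student sample x_tilde and z_t."""
    coef_zt, coef_x, var = posterior_params(alpha_s, sigma_s, alpha_t, sigma_t)
    std = math.sqrt(var)
    return [coef_zt * z + coef_x * x + std * rng.gauss(0.0, 1.0)
            for x, z in zip(x_tilde, z_t)]


# Hypothetical variance-preserving schedule, chosen only for the example.
def schedule(t):
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


alpha_t, sigma_t = schedule(0.8)
alpha_s, sigma_s = schedule(0.4)
z_s = sample_z_s([0.0, 1.0], [0.5, -0.5], alpha_s, sigma_s, alpha_t, sigma_t)
```

The alternative that the post warns against would be z_s = α_s·x̃ + σ_s·(fresh noise), which discards the information carried by z_t.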
⚙️ Two algorithmic variants
1. Alternating moment matching
This version introduces an auxiliary denoiser g_φ, which learns to estimate the generated-side moment:
g_φ(z_s, s) ≈ E_g[x̃ | z_s]
Training alternates between two steps:
🔅 train the auxiliary denoiser to predict generated samples, with regularization toward the teacher
🔅 train the student using the difference between the auxiliary denoiser and the teacher denoiser: g_φ(z_s, s) - g_θ(z_s, s)
Conceptually, this resembles the alternating updates of GAN training, but instead of a discriminator, the auxiliary model estimates denoising moments.
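The alternation can be illustrated with a deliberately degenerate scalar toy. Everything here is an assumption made for illustration and not the paper's setup: the data is a point mass at m (so the teacher moment equals m for every z_s), the student outputs a constant x̃ = η, the auxiliary denoiser is a constant φ, and the regularization toward the teacher is omitted.

```python
def alternating_toy(m=2.0, steps=200, lr_aux=0.5, lr_student=0.1):
    """Scalar toy of the two alternating updates (illustrative only).

    eta : the student's (constant) output x_tilde
    phi : the auxiliary denoiser's (constant) estimate of E_g[x_tilde | z_s]
    m   : the teacher moment E_q[x | z_s], constant for point-mass data
    """
    eta, phi = 0.0, 0.0
    for _ in range(steps):
        x_tilde = eta
        # Step 1: auxiliary denoiser regresses onto the generated sample,
        # i.e. one gradient step on 0.5 * (phi - x_tilde) ** 2.
        phi -= lr_aux * (phi - x_tilde)
        # Step 2: the student moves against the moment gap (phi - m);
        # here d x_tilde / d eta = 1, so the gap is the gradient itself.
        eta -= lr_student * (phi - m)
    return eta, phi


eta, phi = alternating_toy()
print(eta, phi)  # both approach m = 2.0
```

Once the gap φ - m vanishes, the student stops moving, which mirrors the moment-matching condition stated earlier.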
2. Instant moment matching
The second version removes the need for a separate auxiliary denoiser. Instead of explicitly matching moments in data space, it matches teacher-gradient moments in parameter space.
The intuition: if the generated samples follow the correct distribution, the teacher should not “want” to update differently on generated samples than it does on real data. In practice, this uses a two-minibatch estimator and a Jacobian-vector product through the teacher. Only the student is updated.
🧪 Experiments
The ImageNet results were very strong at the time the paper was published, although later methods such as DMD2 perform better.
ImageNet 64×64
Base diffusion model:
• 1024 NFE
• FID 1.42
Moment Matching, alternating:
• 8 NFE
• FID 1.24
So the 8-step student achieves better FID than the 1024-step teacher.
ImageNet 128×128
Base diffusion model:
• 1024 NFE
• FID 1.76
Moment Matching, alternating:
• 8 NFE
• FID 1.49
The same pattern holds at higher resolution: the few-step student does not merely accelerate sampling, but can also improve FID over the long-chain teacher.
The authors also evaluate text-to-image generation at 512×512:
• pixel-space UViT
• no autoencoder or upsampler
• T5 XXL text encoder
• evaluated with zero-shot MS-COCO FID30k and CLIP Score
Best result shown in the presentation:
• Moment Matching, alternating, guidance 0
• 8 NFE
• FID30k 7.25
This is better than the base model with guidance 0.5 at 512 NFE, which gets FID30k 7.9.
💡 Takeaways
Moment Matching Distillation looks like a strong way to turn a heavy diffusion sampler into a fast stochastic generator with only 1–8 sampling steps.
What stands out:
🔅 The method does not simply learn a deterministic shortcut through the denoising chain; it tries to match the distribution via conditional denoising moments.
🔅 The multistep formulation is much stronger than one-step variants.
🔅 In the 8-step regime, the student can outperform the many-step teacher in FID.
🔅 The method scales beyond class-conditional ImageNet to large-scale text-to-image generation.
🔅 One possible explanation for outperforming the teacher: teacher predictions across timesteps are not necessarily perfectly consistent, while a coarse stochastic student may smooth over or avoid some of these errors.
Overall, this looks like a useful technique for speeding up diffusion inference, especially when we want to preserve stochasticity and quality instead of learning a purely deterministic shortcut.
🧭 Why I’m especially interested in this paper
The main reason I wanted to highlight this work is that moment matching did not stop at continuous diffusion models. The same authors later developed this direction further for Discrete Diffusion Models in:
Beyond Single Tokens: Distilling Discrete Diffusion Models via Discrete MMD
This is particularly relevant for us because it is closely connected to our recent work:
IDLM: Inverse-distilled Diffusion Language Models — ICML 2026
Both lines of work are trying to solve a similar bottleneck: discrete diffusion models and diffusion language models can be high-quality, but inference is still expensive because generation usually requires many iterative sampling steps. So the key question is how to distill these models into much faster few-step generators without collapsing quality or diversity.
We will discuss this broader direction at the Popular Reading Group meeting devoted to Discrete Diffusion Models on Monday, May 18.
Please join the discussion chat
The meeting link will be shared later in our Telegram chat
Hi everyone!
I’m David Li, a first-year PhD student at Mohamed bin Zayed University of Artificial Intelligence (MBZUAI). My research interests are mainly in generative models, diffusion models, optimal transport, and related areas of machine learning.
I created LiSearch to share short notes about new papers, interesting ideas, and possible research directions in generative modeling and ML.
The main motivation for this channel is discussion. I don’t want it to be just a list of paper summaries; I’d like it to become a place where researchers and ML enthusiasts can exchange thoughts, ask questions, criticize ideas, suggest papers, and discuss what may be worth exploring next.
I’ll be very glad to see any activity here: comments, questions, opinions, links to papers, or your own research ideas.
A bit more about me:
Google Scholar: https://scholar.google.com/citations?hl=en&user=L88Qc4YAAAAJ
LinkedIn: https://www.linkedin.com/in/david-li-ab07b332b
Telegram: @kekchpek
Welcome to LiSearch!
I’ve published a video where I explain the paper “Multistep Distillation of Diffusion Models via Moment Matching”
I also made a LinkedIn post where I explain the overall idea behind this channel.
I don’t want to repeat everything here, but the main point is this: for all the “beauty” parts, such as YouTube video icons, thumbnails, and similar visuals, I’ll use neural networks. I don’t want to spend too much time on that manually, so that’s why you may see some funny faces or weird-looking icons in the YouTube videos😁
I’ll try to keep the meetings weekly and post new videos every week. If something doesn’t work out, I’ll announce it separately.
I hope this video helps you understand the paper better. I’ll be happy to discuss any ideas, questions, or thoughts in the comments!
IDLM: Inverse-distilled Diffusion Language Models (ICML 2026, Our recent work)
Paper | Code | Checkpoints
⚡️ Can a language model generate a 1024-token sequence in just 16 forward passes?
That would mean producing 1024/16=64 tokens per forward pass.
For today’s standard language models, this sounds almost impossible. They are autoregressive, meaning they generate text token by token: first token, then the next, then the next…
So generating 1024 tokens usually requires 1024 forward passes.
This is one of the biggest bottlenecks in LLM inference.
A promising alternative is Diffusion Language Models. Instead of generating tokens one by one, they try to generate or refine the whole sequence in parallel, potentially removing the need for strict autoregressive decoding.
In theory, this could make generation much faster.
But in practice, diffusion-based language models often turn out to be slower, not faster, than autoregressive models.
The main challenge is the dimensionality of the output space.
If we want the model to generate the next 64 tokens at once, it is not enough to predict one token 64 times independently. Ideally, the model should approximate the joint distribution over all possible 64-token continuations.
But the number of such continuations is enormous.
Even for a relatively small vocabulary, like GPT-2’s vocabulary of 50,257 tokens, the number of possible 64-token sequences is:
50,257⁶⁴ ≈ 10³⁰¹
That is an astronomically large number. We cannot enumerate these possibilities, store them, or explicitly simulate such a distribution.
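The order of magnitude is easy to check in log space. A quick sketch (the helper name is hypothetical), evaluated for a round 60,000-token vocabulary and for GPT-2’s 50,257-token BPE vocabulary:

```python
import math


def log10_num_sequences(vocab_size, length):
    """log10 of vocab_size ** length, computed without overflow."""
    return length * math.log10(vocab_size)


print(round(log10_num_sequences(60_000, 64), 1))  # 305.8
print(round(log10_num_sequences(50_257, 64), 1))  # 300.9
```

Either way, the count has more than 300 decimal digits, so an explicit table over all 64-token continuations is out of the question.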
So the real question becomes:
Can we generate many tokens in parallel while keeping the model’s complexity only linear in sequence length?
The answer is yes, and the key idea is a mixture of distributions.
Instead of trying to explicitly model all possible 64-token continuations, we introduce a latent variable ε sampled from a simple latent space.
Then the model generates tokens independently conditioned on ε.
At first glance, this looks too simple: independent tokens cannot capture complex text structure, right?
But the important part is that tokens are independent only after conditioning on ε.
The final text distribution is obtained by averaging over all latent variables (see the first attached image).
So the model is actually a mixture of many factorized distributions.
And this is powerful: the generator Gθ can encode global structure, style, topic, dependencies, and correlations between tokens through the latent variable ε. As a result, the marginal distribution over text can still be highly expressive, even though each conditional distribution is factorized.
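A tiny worked example shows how a mixture of factorized distributions, p(x) = E_ε[ ∏_i p(x_i | ε) ], ends up with correlated tokens. The setup is an assumption chosen for illustration (binary tokens, a binary latent ε, hypothetical function name), not the model from the paper:

```python
from itertools import product


def mixture_joint(p_eps, p_one_given_eps, length):
    """Exact joint over binary sequences: sum_eps p(eps) * prod_i p(x_i | eps)."""
    joint = {}
    for x in product((0, 1), repeat=length):
        total = 0.0
        for eps, weight in p_eps.items():
            prob = weight
            p_one = p_one_given_eps[eps]
            for token in x:
                # Tokens are independent once eps is fixed.
                prob *= p_one if token == 1 else 1.0 - p_one
            total += prob
        joint[x] = total
    return joint


p_eps = {0: 0.5, 1: 0.5}    # simple latent prior
p_one = {0: 0.1, 1: 0.9}    # P(token = 1 | eps)
joint = mixture_joint(p_eps, p_one, length=2)

marginal_one = joint[(1, 0)] + joint[(1, 1)]   # P(x_1 = 1) = 0.5
print(joint[(1, 1)])          # about 0.41 under the mixture
print(marginal_one ** 2)      # 0.25 if the two tokens were independent
```

Each component is fully factorized, yet the two tokens agree far more often than independent tokens with the same marginals would, because ε carries the shared information.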
This direction has already been explored in recent works such as Di4C (ICML 2025) and VADD (ICLR 2026).
The last image is a great illustration from VADD. Without a latent variable, a factorized model fails to capture dependencies; this is what happens with MDLM. But with a latent variable, VADD can recover structured distributions like checkerboards and spirals.
However, there is still a major problem.
These mixtures in VADD are trained with a VAE-style objective. And in practice, VAE losses can be fragile: they require balancing reconstruction quality against the regularization term. If this balance is not right, the model can learn poor latent representations and produce weak samples.
So the real question becomes:
Can we design a better loss function for training mixtures of distributions?
