While diffusion models excel at generating high-quality images, prior work reports a significant performance gap between diffusion and autoregressive (AR) methods on language modeling. In this work, we show that simple masked discrete diffusion is more performant than previously thought. We apply an effective training recipe that improves the performance of masked diffusion models and derive a simplified, Rao-Blackwellized objective that results in additional improvements. Our objective has a simple form—it is a mixture of classical masked language modeling losses—and can be used to train encoder-only language models that admit efficient samplers, including ones that can generate arbitrary lengths of text semi-autoregressively like a traditional language model. On language modeling benchmarks, a range of masked diffusion models trained with modern engineering practices achieves a new state-of-the-art among diffusion models and approaches AR perplexity.
Diffusion models excel at producing realistic, high-quality images and have received significant attention as potential tools for generating discrete data such as text, biological sequences, and graphs. Unlike autoregressive (AR) approaches, diffusion-based methods are not constrained to generate data sequentially, and therefore have the potential to improve long-term planning, controllable generation, and sampling speed. However, discrete diffusion methods exhibit a performance gap relative to AR models, especially in language modeling. The standard measure of language modeling performance is log-likelihood: when controlling for parameter count, prior work reports a sizable log-likelihood gap between AR and diffusion models.
Applications of diffusion modeling to discrete data fall into two broad categories. The first embeds discrete structures in continuous space and then performs Gaussian diffusion on these continuous representations. More closely related to our method are works that define a diffusion process directly on discrete structures. D3PM introduces a framework with a Markov forward process \( q(\mathbf{z}_t|\mathbf{z}_{t-1}) = \text{Cat}(\mathbf{z}_t; Q_t \mathbf{z}_{t-1}) \), defined by multiplication with matrices \( Q_t \in \mathbb{R}^{n \times n} \) over \( T \) discrete time steps. The matrices are designed such that \( Q_T Q_{T-1} \cdots Q_1 \mathbf{x} \) converges to a stationary distribution.
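To make the absorbing-state instance of this framework concrete, here is a minimal sketch with a toy vocabulary whose last index plays the role of [MASK] and a constant corruption rate `beta_t`; the vocabulary size, rate, and number of steps are illustrative assumptions, not the D3PM configuration.

```python
import torch

def absorbing_Qt(n: int, beta_t: float) -> torch.Tensor:
    """Column-stochastic absorbing-state transition matrix: a token jumps to
    [MASK] (the last index) with probability beta_t and otherwise stays put."""
    m = torch.zeros(n)
    m[-1] = 1.0
    # Q_t acts on one-hot columns: Q_t x = (1 - beta_t) x + beta_t m
    return (1.0 - beta_t) * torch.eye(n) + beta_t * torch.outer(m, torch.ones(n))

n, T = 5, 50
x = torch.eye(n)[0]          # one-hot vector for an arbitrary clean token
z = x.clone()
for _ in range(T):           # apply Q_T Q_{T-1} ... Q_1 to x
    z = absorbing_Qt(n, beta_t=0.1) @ z
print(z)                     # nearly all probability mass sits on [MASK]
```

Because [MASK] is absorbing, repeated application of \(Q_t\) drives any one-hot input toward the all-mask stationary distribution.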
While previous work on discrete diffusion supports general forward processes (e.g., general \(Q_t\) in D3PM), absorbing state (i.e., masking) diffusion consistently achieves the best performance. In this work, instead of supporting general noise processes, we focus on masking and derive tight Rao-Blackwellized objectives that outperform general approaches and do not require continuous-time Markov chain (CTMC) machinery (as used, e.g., in SEDD). We refer to our overall approach as Masked Diffusion Language Models (MDLM).
We focus on forward processes \(q\) that interpolate between clean data \(\mathbf{x} \in \mathcal{V}\), where \(\mathcal{V}\) is the set of all one-hot vectors over \(K\) categories, and a target distribution \(\text{Cat}(\cdot; \boldsymbol{\mathit{\pi}})\). This approach is a direct extension of Gaussian diffusion, where the intermediate sample interpolates between the clean data and white noise. The forward process \(q\) defines a sequence of increasingly noisy latent variables \(\mathbf{z}_t \in \mathcal{V}\), where the time step \(t\) runs from \(t = 0\) (least noisy) to \(t = 1\) (most noisy). The marginal of \(\mathbf{z}_t\) conditioned on \(\mathbf{x}\) at time \(t\) is given by \(q(\mathbf{z}_t|\mathbf{x}) = \text{Cat}(\mathbf{z}_t;\alpha_t \mathbf{x}+(1 - \alpha_t) \boldsymbol{\mathit{\pi}})\), where \(\alpha_t \in [0, 1]\) is a strictly decreasing function of \(t\), with \(\alpha_{t=0} \approx 1\) and \(\alpha_{t=1} \approx 0\). In absorbing state diffusion, we set \(\boldsymbol{\mathit{\pi}} = \mathbf{m}\), where \(\mathbf{m} \in \mathcal{V}\) is the one-hot vector with \(\mathbf{m}_K = 1\), the \(K^\text{th}\) category representing the special masked token, [MASK].
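In practice, sampling from \(q(\mathbf{z}_t|\mathbf{x})\) for a token sequence reduces to independently replacing each token with [MASK] with probability \(1 - \alpha_t\). A minimal PyTorch sketch follows; the `MASK_ID` value and the example token ids are illustrative assumptions.

```python
import torch

MASK_ID = 50257  # illustrative id for the appended [MASK] token (an assumption)

def sample_zt(x: torch.Tensor, alpha_t: float) -> torch.Tensor:
    """Sample z_t ~ q(z_t | x): each position keeps its token with
    probability alpha_t and is replaced by [MASK] otherwise."""
    keep = torch.rand(x.shape, device=x.device) < alpha_t
    return torch.where(keep, x, torch.full_like(x, MASK_ID))

x = torch.tensor([[464, 3290, 11299, 13]])   # a toy batch of token ids
z_t = sample_zt(x, alpha_t=0.3)              # roughly 70% of positions masked
```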
The specific parameterization for \(p_\theta(\mathbf{z}_s | \mathbf{z}_t)\) with \(0 \leq s < t \leq 1\) that we use is: \begin{align}\label{eqn:approx_posterior} p_\theta(\mathbf{z}_s | \mathbf{z}_t) = q(\mathbf{z}_s | \mathbf{z}_t, \mathbf{x} = \mathbf{x}_\theta (\mathbf{z}_t, t)) = \begin{cases} \text{Cat} (\mathbf{z}_s; \mathbf{z}_t), & \mathbf{z}_t \neq \mathbf{m}, \\ \text{Cat} \left( \mathbf{z}_s; \dfrac{ (1 - \alpha_s)\mathbf{m} + (\alpha_s - \alpha_t) \mathbf{x}_\theta (\mathbf{z}_t, t)}{1 - \alpha_t}\right), & \mathbf{z}_t = \mathbf{m}. \end{cases} \end{align} Furthermore, we build two key properties of the absorbing state diffusion process into our denoising model \( \mathbf{x}_\theta (\mathbf{z}_t, t): \mathcal{V} \times [0, 1] \rightarrow \Delta^{K}\), where \(\Delta^{K}\) denotes the \(K\)-simplex: (1) zero masking probabilities, \(\langle \mathbf{x}_\theta(\mathbf{z}_t, t), \mathbf{m} \rangle = 0\), so the denoiser never predicts the [MASK] token; and (2) carry-over unmasking, \(\mathbf{x}_\theta(\mathbf{z}_t, t) = \mathbf{z}_t\) whenever \(\mathbf{z}_t \neq \mathbf{m}\), so tokens that are already unmasked are copied unchanged. We refer to the resulting substitution-based parameterization as SUBS.
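The sampler implied by this parameterization is simple: copy positions that are already unmasked and, for masked positions, either reveal a token drawn from \(\mathbf{x}_\theta\) (with probability \((\alpha_s - \alpha_t)/(1 - \alpha_t)\)) or keep the mask. Below is a minimal sketch of one reverse step; `MASK_ID`, the `logits` produced by the denoiser, and the renormalization details are illustrative assumptions rather than the reference implementation.

```python
import torch
import torch.nn.functional as F

MASK_ID = 50257  # illustrative [MASK] id, assumed to be the last vocab index

def reverse_step(logits, z_t, alpha_s, alpha_t):
    """One ancestral step z_s ~ p_theta(z_s | z_t) for 0 <= s < t <= 1."""
    probs_x = F.softmax(logits, dim=-1).clone()  # denoiser output over the vocab
    probs_x[..., MASK_ID] = 0.0                  # zero masking probability (SUBS)
    probs_x = probs_x / probs_x.sum(-1, keepdim=True)

    # Mixture ((1 - alpha_s) m + (alpha_s - alpha_t) x_theta) / (1 - alpha_t)
    probs = (alpha_s - alpha_t) / (1.0 - alpha_t) * probs_x
    probs[..., MASK_ID] = (1.0 - alpha_s) / (1.0 - alpha_t)
    z_s = torch.distributions.Categorical(probs=probs).sample()

    # Carry-over unmasking: tokens already revealed in z_t are copied unchanged
    return torch.where(z_t != MASK_ID, z_t, z_s)
```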
For a sequence \(\mathbf{x}^{1: L}\) of length \(L\), the SUBS parameterization simplifies the negative evidence lower bound (NELBO) to the following: \begin{align} \mathcal{L}_{\text{NELBO}}^\infty = \mathbb{E}_{q}\int_{t=0}^{t=1} \frac{\alpha_{t}'}{1 - \alpha_t} \sum_{\ell = 1}^{L} \log \langle \mathbf{x}_\theta^\ell(\mathbf{z}_t), \mathbf{x}^\ell \rangle \, \text{d} t, \end{align} where \(\alpha_{t}'\) denotes the time derivative of \(\alpha_{t}\); since \(\alpha_t\) is decreasing and each log term is non-positive, the integrand is non-negative, and the objective is a weighted average of masked language modeling losses. As the results table below shows, MDLM outperforms previous diffusion models and nearly matches the performance of retrained AR models on the LM1B dataset.
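To make the objective concrete, the sketch below assumes, for illustration, a schedule \(\alpha_t = 1 - t\) (so the weight \(\alpha_t'/(1 - \alpha_t)\) becomes \(-1/t\)), a denoiser `model` that returns per-position logits, and an illustrative `MASK_ID`; it is a single-sample Monte Carlo estimate of the NELBO, not the authors' training code.

```python
import torch

MASK_ID = 50257  # illustrative [MASK] id (an assumption)

def mdlm_nelbo(model, x, eps=1e-3):
    """Monte Carlo estimate of the continuous-time NELBO for one batch,
    assuming the schedule alpha_t = 1 - t (so alpha_t' = -1)."""
    b = x.shape[0]
    t = eps + (1.0 - eps) * torch.rand(b, 1, device=x.device)  # clamp t away from 0
    alpha_t = 1.0 - t
    # Forward masking: keep a token w.p. alpha_t, else replace it with [MASK]
    keep = torch.rand(x.shape, device=x.device) < alpha_t
    z_t = torch.where(keep, x, torch.full_like(x, MASK_ID))
    # log <x_theta^l(z_t), x^l> at every position
    log_probs = model(z_t).log_softmax(dim=-1)                 # [B, L, V]
    log_x_theta = log_probs.gather(-1, x.unsqueeze(-1)).squeeze(-1)
    masked = (z_t == MASK_ID).float()   # unmasked positions contribute zero (SUBS)
    weight = -1.0 / t                   # alpha_t' / (1 - alpha_t) for alpha_t = 1 - t
    return (weight * masked * log_x_theta).sum(dim=-1).mean()
```

In other words, training reduces to a reweighted masked-language-modeling cross-entropy over the masked positions.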
Test perplexities (PPL, ↓) on LM1B.

| | Model | Parameters | PPL (↓) |
|---|---|---|---|
| Autoregressive | Transformer-XL Base | 0.46B | 23.5 |
| | OmniNetT | 100M | 21.5 |
| Diffusion | BERT-Mouth | 110M | ≤142.89 |
| | D3PM (absorb) | 70M | ≤77.50 |
| | Diffusion-LM | 80M | ≤118.62 |
| | DiffusionBERT | 110M | ≤63.78 |
| | SEDD | 110M | ≤32.79 |
| Autoregressive (Retrained) | Transformer (33B tokens) | 110M | 22.32 |
| | Transformer (327B tokens) | | 20.86 |
| Diffusion (Ours) | MDLM (33B tokens) | 110M | ≤27.04 |
| | MDLM (327B tokens) | | ≤23.00 |
Zero-shot perplexities (↓) across benchmark datasets.

| | PTB | Wikitext | LM1B | Lambada | AG News | Pubmed | Arxiv |
|---|---|---|---|---|---|---|---|
| AR (Retrained) | 82.05 | 25.75 | 51.25 | 51.28 | 52.09 | 49.01 | 41.73 |
| SEDD (Retrained) | 100.09 | 34.28 | 68.20 | 49.86 | 62.09 | 44.53 | 38.48 |
| MDLM (Ours) | 95.26 | 32.83 | 67.01 | 47.52 | 61.15 | 41.89 | 37.37 |
One of the key contributions of our work is a well-engineered implementation of masked diffusion models. Our experiments demonstrate that these improvements greatly boost performance even for methods previously thought to perform poorly, e.g., D3PM. Below we briefly summarize these implementation details.
@misc{sahoo2024simple,
title={Simple and Effective Masked Diffusion Language Models},
author={Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
year={2024},
eprint={2406.07524},
archivePrefix={arXiv},
primaryClass={cs.CL}
}