MDLM: Simple and Effective Masked Diffusion Language Models
MDLM(Masked Diffusion Language Model)は、Sahoo らが NeurIPS 2024 で発表した離散拡散言語モデルの定式化である (Sahoo ほか 2024年)。離散拡散の foundational な数学を提供した D3PM (Austin ほか 2021年) から数年間にわたって蓄積されてきた masked diffusion の理論を、1 つの簡潔な目的関数に集約した点が最大の貢献である。後続の LLaDA (Nie ほか 2025年) や Dream は実質的にこの定式化の上に立っており、現代的な Diffusion Language Model(DLLM)を理解する出発点として最初に読むべき論文となっている。
図 1 に示すように、訓練は「マスク率を確率変数として変動させた BERT」と読める。
なぜ MDLM を最初に読むべきか
D3PM が提示した離散拡散の枠組みは、uniform / absorbing / discretized Gaussian など多様な遷移行列を統一的に扱う一方、目的関数は KL ダイバージェンスの和として書かれており、実装上の見通しが必ずしも良くなかった。その後 SEDD (Lou ほか 2024年) は concrete score / ratio matching の観点から目的関数を再構築したが、score function を陽に扱うため離散領域特有の煩雑さが残っていた。
MDLM はこの状況に対し、absorbing transition([MASK] 状態への一方向遷移)に絞った上で連続時間極限を取ると、ELBO が 重み \(1/t\) の masked cross-entropy に縮約することを示した。これにより、
- 訓練は BERT のランダムマスク予測の連続時間一般化として実装できる
- 推論は離散時刻でのサンプリングループとして書ける
- 連続拡散モデル(DDPM 等)の denoising score matching と構造的に対応する
という見通しの良い構図が得られる。「DLLM の訓練とは何をしているのか」を 1 本で掴むうえで、MDLM ほど良い入口は他にない。
MDLM と同年に Shi らも独立に類似の masked diffusion の定式化を提案している (Shi ほか 2024年)。両者は記法こそ異なるが、本質的に同じ「absorbing + 連続時間 ELBO → 重み付き masked CE」の構造に到達しており、現在は両論文を合わせて参照することが多い。
定式化の核
記法
語彙サイズを \(V\)、系列長を \(L\) とし、クリーンな系列を \(x_0 = (x_0^1, \dots, x_0^L)\) で表す。各トークンは \(V+1\) 状態を取り得る(通常の語彙 \(V\) 種類に加え、特殊トークン [MASK] を追加)。\(x_t^i\) は時刻 \(t \in [0,1]\) における位置 \(i\) のトークンを表す。
Forward 過程
各位置のトークンを独立に、確率 \(t\) で [MASK] に置換する。\(t = 0\) では \(x_0\) がそのまま残り、\(t = 1\) では全位置が [MASK] になる。各位置 \(i\) について、
\[ q(x_t^i \mid x_0^i) = \begin{cases} 1 - t & x_t^i = x_0^i \\ t & x_t^i = \texttt{[MASK]} \end{cases} \]
これは離散時刻の D3PM における absorbing transition の連続時間版に相当する。一度 [MASK] になった位置は、forward 過程の途中で元のトークンに戻ることはない(absorbing 性質)。
Reverse 過程
逆過程は、\(t\) から \(t - \mathrm{d}t\) への 1 ステップで、[MASK] 位置のうち一部を予測トークンで埋める過程として書ける。位置 \(i\) が時刻 \(t\) で [MASK] のとき、時刻 \(s < t\) における条件付き分布は
\[ q(x_s^i \mid x_t^i = \texttt{[MASK]}, x_0^i) = \begin{cases} \frac{t - s}{t} & x_s^i = x_0^i \\ \frac{s}{t} & x_s^i = \texttt{[MASK]} \end{cases} \]
となる。実装上はこの真の posterior に対して、\(x_0\) をニューラルネット \(p_\theta(x_0 \mid x_t)\) で予測することで近似する。すなわち \(x_0\)-prediction の精神で逆過程を学習する。
目的関数
連続時間 ELBO の積分を、上記の forward / reverse 過程の選択のもとで具体化すると、最終的に次の損失関数に縮約する。
定理 1 (MDLM の目的関数 (Sahoo ほか 2024年)) 連続時間 absorbing forward 過程の下で、MDLM の負の ELBO は次の損失と等価になる。
\[ \mathcal{L}_\text{MDLM} = \mathbb{E}_{t \sim \mathcal{U}(0,1)} \, \mathbb{E}_{x_t \sim q(\cdot \mid x_0)} \left[ \frac{1}{t} \sum_{i=1}^{L} \mathbf{1}[x_t^i = \texttt{[MASK]}] \, \log p_\theta(x_0^i \mid x_t) \right] \tag{1}\]
ここで本質的なポイントは次の 2 点である。
[MASK]位置でのみ評価される: \(\mathbf{1}[x_t^i = \texttt{[MASK]}]\) により、unmasked 位置の loss は寄与しない- 重み \(1/t\): 時刻 \(t\) が小さい(マスクが少ない)ほど 1 マスクあたりの寄与が大きい
定理 1 は、BERT のランダムマスク予測損失に対し「マスク率を \(t \in [0,1]\) で動かしながら、\(1/t\) で重み付ける」という拡張になっている。これが「MDLM は連続時間版 BERT である」と言われる所以である。
学習目的の導出の流れ
詳細な計算は原論文 §3 および Appendix A に譲るが、(式 1) に至る論理の骨格は次の通りである。
ELBO の離散時刻版
時刻を \(0 = t_0 < t_1 < \dots < t_N = 1\) と離散化すると、ELBO は
\[ \log p_\theta(x_0) \geq -\mathbb{E}_q \left[ \sum_{n=1}^{N} D_\text{KL}\left( q(x_{t_{n-1}} \mid x_{t_n}, x_0) \,\big\|\, p_\theta(x_{t_{n-1}} \mid x_{t_n}) \right) \right] + \text{const.} \]
と書ける。各 KL は位置ごとに分解できる(forward 過程が位置独立なため)。
各位置の KL の評価
位置 \(i\) における KL は、\(x_{t_n}^i\) が [MASK] か通常トークンかで場合分けされる。
- \(x_{t_n}^i \ne \texttt{[MASK]}\)(既に確定): absorbing 性質より \(x_{t_{n-1}}^i = x_{t_n}^i\) が確定的に分かるため、posterior と prior の両方が delta になり KL は 0
- \(x_{t_n}^i = \texttt{[MASK]}\): posterior は「\(x_0^i\) で埋まる確率 \((t_n - t_{n-1})/t_n\)、
[MASK]のまま残る確率 \(t_{n-1}/t_n\)」であり、ここでのみ非自明な KL が発生
unmasked 位置の loss が消えるのは、まさに absorbing 性質に由来する。情報が一度 forward 過程で「噴出」して [MASK] になったあと、それが時刻 \(t\) で残っていれば未確定(loss あり)、消えていれば既確定(loss なし)という単純な構造になる。
連続時間極限
\(N \to \infty\) の極限を取ると、ステップ幅 \(\Delta t = t_n - t_{n-1}\) が 0 に近づき、KL の主要項が
\[ \frac{\Delta t}{t_n} \cdot (- \log p_\theta(x_0^i \mid x_{t_n})) \]
の形に整理される。これを積分すると、時刻 \(t\) について \(1/t\) の重み付けが現れる。\(1/t\) という重みは、絶対時間ではなく「単位時間あたりに forward 過程で何個のトークンが [MASK] に吸収されるか」のレートから来ていると理解しておくと直感が掴みやすい。
推論時の denoising loop
訓練が終わった後、サンプリングは reverse 過程を離散時刻で辿ることで行う。基本形は次の擬似コードで表せる。
# x_T: 全位置 [MASK] で初期化(T はステップ数)
x = [MASK] * L
for t in linspace(1.0, 0.0, T+1)[:-1]:
s = t - 1.0 / T
# ニューラルネット p_theta による予測
logits = model(x)
# MASK 位置のみ予測サンプリング
for i in masked_positions(x):
if rand() < (t - s) / t:
x[i] = sample(softmax(logits[i]))
# else: MASK のまま残すこの基本形では「各 [MASK] 位置を独立に確率 \((t - s)/t\) で確定させる」ことになり、ステップ数 \(T\) を増やせば 1 ステップあたりの確定数が減って品質が上がる。逆に \(T\) を減らせば高速だが品質が落ちる。\(T\) は計算量と品質のトレードオフを支配するハイパーパラメータである。
実用上は「ランダムに確率で確定」ではなく、「予測 logit の信頼度が高い位置から順に確定」させる戦略がよく使われる。これは画像生成の MaskGIT に由来し、LLaDA など後続モデルが採用している。MDLM の理論的な reverse 過程はランダム確定だが、サンプラの選択は理論と独立に交換可能である。
absorbing transition の必然性
D3PM では uniform / absorbing / discretized Gaussian など複数の遷移を扱えたが、MDLM はあえて absorbing に絞る。なぜそれが自然な選択なのか。
「噴出した情報は戻ってこない」一方向性
absorbing 過程の本質は、情報の損失が一方向的であることだ。一度 [MASK] になった位置は、forward 過程の途中で別のトークンに戻ったり、別の通常トークンに変化したりしない。これにより、
- reverse 過程で「現在
[MASK]の位置は元が何であれ未知」「現在通常トークンの位置は元のまま」という単純な区別ができる - 学習目的の評価は
[MASK]位置のみで行えばよい - BERT のマスク予測タスクと直接対応する
uniform 遷移(任意のトークンに置換)の場合、reverse 過程で「現在トークン \(a\) にあるが、それが元の \(a\) なのか別の文字が変化したものなのか」を区別する必要があり、目的関数が複雑になる。absorbing が「単純な目的関数」と「BERT との接続」を同時に成立させる選択である。
言語データとの相性
言語データにおいて「特定のトークンが [MASK] に置換される」は、欠損・伏字・穴埋めという自然な操作と対応する。uniform 置換(別のランダムトークンに置換)よりも、テキスト処理の直感に合う。
実験結果と scaling
MDLM 論文では LM1B・OpenWebText を用いた実験により、次の点を示している。
| 比較対象 | MDLM の位置付け |
|---|---|
| D3PM (absorbing) | より良い perplexity を達成 |
| SEDD | 同等以上、かつ実装が簡潔 |
| AR (GPT-2 同規模) | わずかに劣るが同程度に scale する |
特に重要なのは、AR と同程度の scaling 則に従って性能が伸びることである。すなわち、計算量・データ量・モデルサイズを増やしたときの perplexity の減衰が AR と類似のパターンを示し、DLLM が「小規模でだけ動くオモチャ」ではないことを示唆している。この観察は後の LLaDA(8B スケール)の動機付けにもなっている。
読み方の優先順位
論文を読む際の各セクションの重要度を表にまとめておく。
| セクション | 重要度 | 内容 |
|---|---|---|
| §2 定式化 | 必読(2 周以上) | forward / reverse 過程の定義、記法 |
| §3.1 目的関数の導出 | 必読 | ELBO から \(1/t\) 重み付き CE への簡略化 |
| §3.2 SUBS パラメータ化 | 推奨 | \(x_0\) 予測のヘッド設計、[MASK] 出力を 0 にする工夫 |
| §4 sampling | 必読 | 推論ループ、ancestral / analytic samplers |
| §5 実験 | 概観で十分 | LM1B・OWT・zero-shot perplexity |
| Appendix A | scan で十分 | D3PM 等価性の証明、連続時間極限の厳密化 |
| Appendix B-C | リファレンス | 派生損失・実装詳細 |
特に §2 と §3.1 は、本書の他章を読む際の前提知識となるため、最低 2 周読んで「forward 過程の定義」「ELBO がなぜ [MASK] 位置のみの CE になるか」を自分の言葉で説明できる状態を目指したい。
この論文を読んだ後に分かること
MDLM を一通り読むと、以下の理解が得られる。
- DLLM の訓練の正体: 「ノイズスケジュール付きの BERT 訓練」だと割り切ってよい。マスク率 \(t\) を一様にサンプリングし、\(1/t\) で重み付ける以外、BERT との差は本質的にはない。
- 推論時の denoising step の意味: 離散時刻 \(t_n\) でのサンプリングであり、ステップ数 \(T\) は計算量と品質のトレードオフのハイパラ。理論上は \(T \to \infty\) で連続時間 reverse 過程に近づく。
- absorbing 性質の役割: 「unmask したら確定」という性質が、目的関数を
[MASK]位置のみで評価する形に簡略化する根本原因。 - AR との関係: AR は左から右への 1 方向 unmask(既存トークンの予測を順次行う)と見なせ、DLLM は順序自由な unmask に一般化したもの、と捉えられる。
連続拡散モデルとの対応
連続拡散(DDPM, VP-SDE 等)における denoising score matching(DSM)は、ノイズ強度 \(\sigma_t\) で重み付けされた L2 損失
\[ \mathcal{L}_\text{DSM} = \mathbb{E}_t \, \mathbb{E}_{x_t} \left[ w(t) \, \| s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t \mid x_0) \|^2 \right] \]
の形を取る。MDLM の目的関数 (式 1) は、これと 構造的に同型 である。
| 連続拡散(DSM) | MDLM |
|---|---|
| L2 損失(スコア \(s_\theta\) 回帰) | masked cross-entropy(\(x_0\) 予測) |
| ノイズ強度依存の重み \(w(t)\) | 時刻依存の重み \(1/t\) |
| forward: ガウシアンノイズ付与 | forward: 確率 \(t\) で [MASK] 化 |
| reverse: SDE / ODE 積分 | reverse: 離散時刻 unmask |
両者とも「\(x_0\)-prediction の精神で逆過程を学習する」点で共通しており、損失の重み構造が「重み付き回帰 vs 重み付き分類」の差として現れている、と理解しておくと統一的に見える。
→ 詳細: 連続拡散と離散拡散の橋渡し
関連手法へのリンク
MDLM を起点に、次の章で派生・関連手法を扱う。
- 派生・スケール: LLaDA: 大規模 Masked DLM とサンプリング — MDLM の定式化を 8B パラメータにスケールし、実用的なサンプラを提示
- 別系統の離散拡散: D3PM と SEDD: 離散拡散の別の選択肢 — absorbing 以外の遷移行列、score-based な定式化
- サンプラの源流: MaskGIT: Confidence-based Iterative Unmasking の源流 — 画像生成における confidence-based unmasking
