LLMの事前学習時に1つの位置から複数の将来トークンを同時に予測するよう訓練することで、モデルの表現力を向上させると同時に、推論時に2〜4倍の高速化を実現する学習・推論統合型の並列デコーディング手法。
Multi-Token Prediction(MTP)は、Meta AIのGloeckle et al.が2024年に発表した「Better & Faster Large Language Models via Multi-token Prediction」で提案された手法である。従来のLLMが次の1トークンのみを予測するのに対し、MTPは1つの位置から将来のN個のトークンを同時に予測するようモデルを訓練する。この設計により、事前学習の段階からモデルが長期的な計画能力を獲得し、推論時にはN個の出力ヘッドを活用して並列デコーディングが可能になる。
MTPの学習アーキテクチャは以下の構成を取る。
MTPの最も注目すべき成果は、同一の計算予算(FLOPs)でNext-Token Prediction(NTP)より高い品質を達成する点である。
| モデルサイズ | 学習方式 | HellaSwag | MMLU | HumanEval | 学習FLOPs |
|---|---|---|---|---|---|
| 7B | NTP (baseline) | 78.2% | 45.3% | 28.0% | 1.0x |
| 7B | MTP (N=4) | 80.1% | 47.8% |
| 34.2% |
| 1.0x(同一) |
| 13B | NTP (baseline) | 82.4% | 52.1% | 35.4% | 1.0x |
| 13B | MTP (N=4) | 83.9% | 54.6% | 41.5% | 1.0x(同一) |
| 34B | NTP (baseline) | 85.7% | 60.2% | 42.8% | 1.0x |
| 34B | MTP (N=4) | 87.0% | 62.4% | 49.1% | 1.0x(同一) |
同一FLOPsでHumanEval(コード生成)が+6〜7%ポイント、MMLU(汎用知識)が+2〜3%ポイント向上している。特にコード生成タスクでの改善が顕著であり、MTPがコードのような構造的な長期依存関係の学習に効果的であることを示している。
MTPで事前学習されたモデルは、N個の出力ヘッドを推論時に並列デコーディングに転用できる。
2026年時点で、MTPは以下の主要モデルで採用または採用予定とされている。
| モデル/組織 | MTP採用状況 | Nの値 | 備考 |
|---|---|---|---|
| DeepSeek V3/R1 | 採用済み | N=1(auxiliary) | MTPをauxiliary lossとして使用、推論時は通常NTP |
| Meta Llama 4 | 採用済み | N=4 | 事前学習からMTP組み込み、並列推論ネイティブ対応 |
| Qwen 3 | 未採用 | - | NTP + Speculative Decodingを採用 |
| Mistral Large 2 | 部分採用 | N=2 | 一部チェックポイントで実験的に適用 |
MTPはSpeculative Decodingの「自己完結型」バリアントと捉えることができる。
| 比較軸 | MTP (Self-Speculative) | Standard Speculative |
|---|---|---|
| ドラフト生成 | 自身のN個のヘッド | 外部小型モデル |
| 追加メモリ | +10〜20% | +20〜40% |
| 追加学習 | 事前学習時に組み込み | 不要 |
| 既存モデル適用 | 不可(再学習必要) | 可能 |
| 品質 | NTPより高い(MTP効果) | NTPと同一(ロスレス) |
| 高速化倍率 | 2.0〜3.5x | 2.0〜3.5x |
MTPの最大の制約は既存モデルに後付けできない点である。事前学習の段階からN個のヘッドを組み込む必要があるため、既にリリースされたモデル(Llama 3.x等)には適用できない。この点で、Medusa/EAGLEやSpeculative Decodingとは相互補完的な関係にある。
Q1: MTPのNは何に設定すべきですか? A: Meta AIの実験ではN=4が品質・効率のバランスが最も良い。N=2でも十分な品質向上が得られるが、N=8以上では追加ヘッドの計算コストが品質向上を上回る。コード生成特化ならN=4〜6、汎用ならN=4が推奨。
Q2: MTPで事前学習されたモデルは通常のNTPとしても使えますか? A: 使える。推論時に最初のヘッド(next-token head)のみを使えば、通常のNTPモデルと同様に動作する。MTP学習自体がモデルの表現力を向上させるため、NTPモードでも品質はベースラインを上回る。
Q3: MTPの学習コストはどの程度増加しますか? A: 学習時のFLOPsはN個のヘッドの追加分として5〜15%増加する。ただし、同一品質を達成するために必要なトークン数がMTPでは10〜20%少なくて済むため、総学習コストは概ね同等またはやや低い。
Q4: MTPモデルに後からMedusa/EAGLEを追加できますか? A: 可能。MTPのN個のヘッドに加えてMedusa/EAGLEヘッドを追加学習することで、さらなる高速化が期待できる。ただし、MTP自体が十分な並列デコーディング能力を持つため、追加効果は限定的(+10〜20%程度)である。