Medusa(メデューサ)は、ターゲットLLMの最終隠れ層に複数の追加デコーディングヘッドを取り付け、各ヘッドが将来の異なる位置のトークンを同時に予測することで、外部ドラフトモデルなしに投機的デコーディングを実現する手法である。Medusa-1は典型的採択(typical acceptance)による近似検証、Medusa-2は修正棄却サンプリングによる厳密検証をサポートし、ツリーベースの候補構造と組み合わせることで2〜3倍のスピードアップを実現する。
Medusa(Cai et al., 2024、「Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads」)は、投機的デコーディングにおける「外部ドラフトモデルが必要」という制約を根本的に解消するアプローチである。従来の投機的デコーディングでは、ターゲットモデルとは別に小型のドラフトモデルをGPUメモリにロードする必要があったが、Medusaはターゲットモデル自体に複数の予測ヘッドを追加するだけでドラフト機能を統合する。
Medusaのアーキテクチャは以下の構成要素からなる。
| コンポーネント | 説明 | パラメータ数 |
|---|---|---|
| ベースモデル(Backbone) | 元のLLM(Llama、Vicunaなど)のTransformerブロック全体 | 元のモデルと同一 |
| オリジナルLMヘッド(Head 0) | 次トークン(位置t+1)を予測する元のヘッド | hidden_dim × vocab_size |
| Medusa Head 1 | 位置t+2のトークンを予測 | hidden_dim × vocab_size + α |
| Medusa Head 2 | 位置t+3のトークンを予測 | hidden_dim × vocab_size + α |
| Medusa Head k | 位置t+k+1のトークンを予測 | hidden_dim × vocab_size + α |
各Medusa Headは、ベースモデルの最終隠れ層の出力を入力とし、1〜2層のFeed-Forward Network(FFN)とReLU活性化関数を経て語彙サイズの出力を生成する。αはFFN層の追加パラメータ数であり、通常はhidden_dim × hidden_dim程度である。Medusa Headの追加パラメータはモデル全体の0.5〜2%程度と非常に小さい。
動作フローは以下の通りである。
Medusaには検証方式の異なる2つのバージョンが存在する。
Medusa-1は、標準的な修正棄却サンプリングではなく、「典型的採択(Typical Acceptance)」と呼ばれる近似的な検証方式を使用する。典型的採択では、ターゲットモデルの出力確率が一定の閾値を超えるトークンを採択する。
具体的には、位置iの候補トークンxiに対して、ターゲットモデルの確率p(xi)がエントロピーベースの閾値ε以上であれば採択する。この方式は実装が単純で高速だが、出力分布がターゲットモデルと厳密には一致しないため、ごくわずかな品質差異が生じる可能性がある。
Medusa-2は、標準的な投機的デコーディングと同じ修正棄却サンプリングを適用する。これにより、出力分布がターゲットモデルと数学的に同一であることが保証される。ただし、Medusa HeadはResidue Connection(残差接続)を含む改良版アーキテクチャを使用しており、Medusa-1よりも高い予測精度を実現する。
| 特性 | Medusa-1 | Medusa-2 |
|---|---|---|
| 検証方式 | 典型的採択(近似) | 修正棄却サンプリング(厳密) |
| 出力分布 | ターゲットと近似的に一致 | ターゲットと厳密に一致 |
| Head構造 | 単純FFN | Residual FFN(残差接続付き) |
| 訓練方式 | ベースモデル固定、Head のみ訓練 | ベースモデル固定、Head のみ訓練 |
| スピードアップ |
| 2〜2.5倍 |
| 2〜3倍 |
| 適用シーン | 品質よりスループット優先 | 厳密な品質保証が必要 |
Medusa Headの訓練は、ベースモデルのパラメータを完全に固定(freeze)した状態で、追加ヘッドのパラメータのみを更新する。これにより、元のモデルの性能を一切損なわずにMedusa機能を付加できる。
訓練データとしては、ベースモデルの訓練データのサブセット、またはShareGPTなどの対話データセットが使用される。損失関数は、各Medusa Headが対応する将来位置のトークンを正しく予測するクロスエントロピー損失の合計である。
| 訓練パラメータ | 推奨値 | 説明 |
|---|---|---|
| エポック数 | 1〜3 | 少数エポックで十分な精度が得られる |
| 学習率 | 1e-3〜3e-3 | AdamW optimizer |
| バッチサイズ | 16〜64 | GPUメモリに応じて調整 |
| 訓練データ量 | 10K〜100K サンプル | ShareGPT程度のデータで十分 |
| GPU時間 | 数時間〜1日 | A100 1枚で7Bモデル用ヘッドを訓練可能 |
| Head数 | 3〜5 | 5以上はスピードアップの飽和が見られる |
訓練のポイントとして、各ヘッドは独立に訓練することも、全ヘッドを同時に訓練することも可能だが、同時訓練の方が全体的なバランスが良好になる傾向がある。また、Head 1(1つ先のトークン予測)の精度が最も高く、Head kが大きくなるほど予測が困難になるため、損失関数にヘッド番号に応じた重み付けを行うことも検討に値する。
Medusaでは、各ヘッドが出力する上位候補の組み合わせからツリー構造の候補集合を構築する。例えば、Head 0のtop-3、Head 1のtop-3、Head 2のtop-2を組み合わせると、最大3×3×2 = 18パスのツリーが構成される。
ツリーの構築には、「Medusa Tree」と呼ばれる事前定義されたトポロジーが使用される。Cai et al.はいくつかの標準的なMedusa Tree構成を提案しており、以下が代表的である。
| ツリー名 | ノード数 | 構造 | 期待採択数 |
|---|---|---|---|
| mc_sim_7b_63 | 63 | 幅広、浅い | 3.0〜3.5 |
| mc_sim_7b_127 | 127 | 幅広、中深度 | 3.5〜4.0 |
| mc_sim_7b_255 | 255 | 最大 | 4.0〜4.5 |
最適なツリーサイズはGPUメモリとフォワードパスのコストに依存する。ノード数が多いほど採択チャンスは増えるが、ツリーアテンションの計算コストも増加するため、デプロイ環境ごとにプロファイリングして決定することが推奨される。
EAGLE(Li et al., 2024、「EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty」)は、Medusaの後継として位置づけられる手法であり、トークンレベルではなく特徴量レベルでの自己回帰予測を行う。
Medusaでは各ヘッドが独立に将来トークンを予測するが、EAGLEではターゲットモデルの隠れ状態(特徴量)の次ステップを予測する軽量なオートリグレッシブネットワークを追加する。この特徴量レベルの予測は、トークンレベルの予測よりも平滑で学習が容易であるため、少ないパラメータと訓練データでより高い採択率を達成できる。
| 比較項目 | Medusa | EAGLE |
|---|---|---|
| 予測対象 | トークン(離散) | 特徴量(連続) |
| ヘッド構造 | 独立FFN | 自己回帰FFN |
| 採択率 | 0.6〜0.75 | 0.7〜0.85 |
| スピードアップ | 2〜3倍 | 2.5〜4倍 |
| 追加パラメータ | 0.5〜2% | 1〜3% |
| 訓練コスト | 数時間(A100 1枚) | 数時間〜半日(A100 1枚) |
EAGLE-2ではさらに、文脈に応じたツリー構造の動的調整(context-aware tree construction)が導入され、入力に応じて最適なツリートポロジーが自動選択される。
Medusaを本番環境にデプロイする際の主な考慮事項は以下の通りである。
実用的には3〜5個が推奨される。Head数を増やすと、より遠い将来のトークンを予測できるが、遠いほど予測精度が低下するため、Head 5以降の追加による採択率向上は限定的である。Cai et al.の実験では、Head数3で約2倍、Head数5で約2.5倍のスピードアップが報告されており、5を超えると飽和傾向にある。
Transformer系のデコーダオンリーアーキテクチャであれば、原理的に全てのモデルに適用可能である。LLaMA、Mistral、Phi、Gemma等の主要モデルで実績がある。ただし、Mixture of Experts(MoE)モデルでは最終隠れ層の挙動が異なるため、Head設計に追加の調整が必要な場合がある。
新規デプロイではEAGLE(特にEAGLE-2)を推奨する。EAGLEはMedusaよりも高い採択率を達成し、特に長文生成タスクでの優位性が顕著である。ただし、EAGLEの訓練はMedusaよりやや複雑であり、既存のMedusaヘッドが利用可能な環境ではMedusaのまま運用する選択肢も合理的である。vLLMでは両方がサポートされているため、実際のワークロードでベンチマークを取って判断するのが望ましい。