RAFT(Retrieval Augmented Fine-Tuning)の訓練手法は、oracle文書とdistractor文書を混在させたコンテキストでLLMをファインチューニングし、Chain-of-Thought形式で根拠付き回答を生成させる。P比率(oracle含有率)の制御とCoT回答生成がRAFT特有の訓練設計の核心である。
RAFT(Retrieval Augmented Fine-Tuning)の訓練手法は、従来のSFT(Supervised Fine-Tuning)を拡張し、RAGパイプラインに最適化された学習プロセスを実現する。標準的なSFTでは(質問, 回答)ペアで訓練するが、RAFTでは(質問, コンテキスト文書群, Chain-of-Thought回答)の三つ組で訓練を行う。
コンテキスト文書群にはoracle文書(正解情報を含む文書)とdistractor文書(無関係だが同ドメインの文書)が混在しており、モデルはこのノイズの中から必要な情報を選別して回答する能力を獲得する。この訓練設計により、実際のRAGパイプラインでretrieverが返す不完全な文書セットに対しても堅牢な応答が可能になる。
訓練の入出力形式は以下の通りである:
| 要素 | 内容 | 形式 |
|---|---|---|
| 入力:質問 | ドメイン固有のQAクエリ | 自然言語テキスト |
| 入力:文書群 | D* + D1, D2, ..., Dk | 連結テキスト(区切り付き) |
| 出力:CoT回答 | 推論過程 + 引用 + 最終回答 | 構造化テキスト |
P比率はRAFT訓練における最も重要なハイパーパラメータの一つである。P比率は、訓練データ全体のうちoracle文書がコンテキストに含まれるサンプルの割合を定義する。
P=1.0の場合、すべてのサンプルにoracle文書が含まれる。これは一見最も情報量が多い設定に見えるが、モデルが「コンテキストには常に正解がある」と学習してしまい、retrieverが失敗してoracle文書が含まれない場合にパフォーマンスが大幅に低下する。
P=0.0の場合、oracle文書が一切含まれない。これはクローズドブック試験と同等であり、RAGの利点が完全に失われる。
最適なP比率はドメインとタスクの特性に依存するが、論文の実験結果から以下の指針が得られている:
| P比率 | 特性 | 推奨シーン |
|---|---|---|
| 0.2〜0.4 | パラメトリック知識重視 | retriever精度が低いドメイン |
| 0.4〜0.6 | バランス型 | 汎用ドメインQA |
| 0.6〜0.8 | コンテキスト重視 | retriever精度が高いドメイン |
| 0.8〜1.0 | 最大コンテキスト活用 | 閉じたドメイン(全文書が高品質) |
実践的には、P=0.5〜0.7の範囲がほとんどのドメインで良好な結果を示す。開発時にはP比率を変えた複数のモデルを訓練し、バリデーションセットで最適値を選定することが推奨される。
RAFTの訓練において、回答は単純な最終回答ではなく、Chain-of-Thought(CoT)形式で生成される。CoT回答は以下の構造を持つ:
CoT形式の訓練には以下のメリットがある:
| メリット | 説明 |
|---|---|
| 根拠の明示 | 回答の根拠となる文書箇所が引用されるため、ユーザーが検証可能 |
| ハルシネーション抑制 | 推論過程を明示することで、根拠のない回答が検出しやすくなる |
| 情報抽出能力の向上 | 引用を含む訓練により、モデルが文書内の関連箇所を正確に特定する能力が向上 |
| デバッグ容易性 | 誤回答の原因が推論過程のどのステップにあるかを特定しやすい |
CoT回答の生成には、GPT-4などの高性能モデルを使用して教師データを作成する方法が一般的である。具体的には、oracle文書と質問を高性能モデルに与え、引用付きの詳細な回答を生成させ、これを訓練データとして使用する。
RAFTの損失関数は、標準的な言語モデリングの交差エントロピー損失をベースとするが、CoT回答の構造を考慮した設計が重要である。
基本的な損失関数は以下の形式をとる:
L = -Σ log P(ti | t1, ..., ti-1, Q, C)
ここで、tiは回答トークン、Qは質問、Cはコンテキスト文書群である。
重要な設計上の考慮点として、以下がある:
引用部分への重み付け:CoT回答の中で文書を引用している部分に対して、損失の重みを増加させることで、文書参照能力の学習を促進できる。ただし、過度な重み付けはコピー的な回答を生成する傾向があるため、適度な調整が必要である。
Distractor無視の学習:oracle文書を含まないサンプル(1-P比率の割合)では、モデルはdistractor文書を無視してパラメトリック知識で回答する必要がある。この場合の損失は、distractor文書が提示されていても無関係であることを示す推論過程を含むCoT回答に対して計算される。
正則化:RAFTでは、ベースモデルの汎化能力を維持するための正則化が重要である。LoRA/QLoRAなどのパラメータ効率的手法を使用する場合、ランクの選択が正則化の役割も果たす。典型的にはランク16〜64の範囲が推奨される。
| 損失コンポーネント | 重み(目安) | 目的 |
|---|---|---|
| CoT全体の交差エントロピー | 1.0 | 基本的な回答生成能力 |
| 引用部分の追加重み | 0.1〜0.5 | 文書参照能力の強化 |
| 最終回答部分の追加重み | 0.2〜0.3 | 回答精度の向上 |
| KL正則化(オプション) | 0.01〜0.1 | ベースモデルからの乖離防止 |
RAFT訓練を実際に実装する際の標準的な手順は以下の通りである。
ステップ1:ベースモデルの選定 RAFTはアーキテクチャ非依存だが、実用上はLlama 3系、Mistral、Qwen 2.5などの7B〜13Bモデルが訓練コストと性能のバランスが良い。大規模モデル(70B以上)はベースラインが高いためRAFTの追加効果が相対的に小さくなる傾向がある。
ステップ2:訓練フレームワークの準備 Hugging Face Transformers + TRL(Transformer Reinforcement Learning)ライブラリ、またはAxolotlなどのファインチューニングフレームワークを使用する。LoRA/QLoRAを適用する場合はPEFTライブラリも必要となる。
ステップ3:データローダーの実装 各バッチでP比率に基づいてoracle文書の含有/非含有を確率的に決定するデータローダーを実装する。文書の連結順序はランダム化し、モデルが位置情報に依存しないようにする。
ステップ4:訓練ループの実行 標準的なSFTと同様の訓練ループを実行する。バッチサイズ、学習率、ウォームアップステップ、勾配蓄積ステップなどのハイパーパラメータは標準的なSFTの知見が適用できる。
ステップ5:評価と反復 バリデーションセットでの性能を定期的に評価し、P比率やdistractor数などのRAFT固有パラメータを調整する。
| ハイパーパラメータ | 推奨値 | 備考 |
|---|---|---|
| 学習率 | 1e-5〜5e-5 | LoRA使用時は2e-4〜5e-4 |
| バッチサイズ | 4〜16 | GPU メモリに応じて調整 |
| エポック数 | 3〜5 | 過学習監視必須 |
| LoRAランク | 16〜64 | 大きいほど表現力↑、過学習リスク↑ |
| 最大系列長 | 2048〜4096 | コンテキスト文書数に応じて |
| ウォームアップ比率 | 0.03〜0.1 | 標準的なSFTと同等 |
静的なP比率で十分な性能が得られることが多いが、カリキュラム学習的なアプローチとして、訓練初期にはP比率を高く(0.8程度)設定してoracle文書からの情報抽出を重点的に学習させ、後半にかけてP比率を下げて(0.4程度)ノイズ耐性を強化する手法も研究されている。ただし、この動的P比率の効果は静的最適P比率と比較して劇的な差を生むわけではなく、まずは静的P比率での最適化を優先すべきである。
CoT回答の教師データは、GPT-4やClaude等の高性能モデルを使用して生成するのが最も効率的である。具体的には、oracle文書と質問を高性能モデルに与え、「文書から関連箇所を引用しながら段階的に推論し、最終回答を導出してください」というプロンプトで回答を生成させる。生成された回答は人手で品質チェックし、引用の正確性や推論の妥当性を検証することが推奨される。大規模なデータセット構築では、品質チェックの自動化(別のLLMによる検証)も有効である。
最も注意すべきはoverfitting(過学習)である。RAFTの訓練データはドメイン固有であるため、データ量が限られる場合が多く、過学習のリスクが高い。対策として、(1)LoRA/QLoRAでパラメータ更新量を制限する、(2)早期停止を適用する、(3)データ拡張(質問の言い換え、distractor文書の入れ替え)を行う、(4)ドロップアウト率を若干上げる、といった手法が有効である。また、ベースモデルの汎化能力を維持するため、一般的なQAデータを少量混合する手法も検討に値する。
モデルのフォワード/バックワードパスの計算コスト自体は標準SFTと同等である。ただし、RAFTでは入力系列長がコンテキスト文書群を含むため長くなり、結果としてGPUメモリ使用量と訓練時間が増加する。典型的には、5つの文書(各500〜1000トークン)をコンテキストに含める場合、入力系列長は3000〜6000トークンとなり、標準SFT(500〜1000トークン)の3〜6倍の計算コストがかかる。A100 80GB GPUを使用する場合、7Bモデルで1万サンプルの訓練に約4〜8時間を要する。