深層学習の学習時に中間活性化値をすべてメモリに保持する代わりに、一部のレイヤーの出力のみを保存し、逆伝播時に必要な活性化値を再計算する手法。GPUメモリ使用量を大幅に削減し、より大きなモデルやバッチサイズでの学習を可能にする。
Activation Checkpointing(勾配チェックポインティング、Gradient Checkpointing とも呼ばれる)は、ニューラルネットワークの学習時における GPU メモリ消費を削減するための手法です。通常、順伝播(Forward Pass)で計算した各レイヤーの中間活性化値はすべてメモリに保持され、逆伝播(Backward Pass)で勾配計算に使用されます。大規模モデルではこの活性化値のメモリ消費がモデルパラメータ自体の数倍に達するため、学習のボトルネックとなります。
Transformer ベースの LLM では、活性化値のメモリ消費は以下の要因で決まります。
| パラメータ | メモリへの影響 |
|---|---|
| バッチサイズ (B) | 線形に増加 |
| シーケンス長 (S) | 線形〜二乗に増加(Self-Attention は S² に比例) |
| 隠れ層の次元 (H) | 線形に増加 |
| レイヤー数 (L) | 線形に増加 |
例えば、Llama 2 7B(32層、H=4096)をシーケンス長 4096、バッチサイズ 4 で学習する場合、活性化値だけで約 60GB のメモリを消費します。モデルパラメータ(FP16 で約 14GB)とオプティマイザ状態(Adam で約 56GB)を加えると、合計 130GB 以上が必要です。
Activation Checkpointing は「メモリと計算のトレードオフ」を実現します。
チェックポイントなし(標準): 全 L 層の活性化値を保持。メモリ消費 O(L)、再計算なし。
全レイヤーチェックポイント: 特定の間隔(例えば √L 層ごと)にチェックポイントを設定。チェックポイント間の活性化値は破棄し、逆伝播時に再計算。メモリ消費 O(√L)、順伝播の約 33% を再計算。
セグメント単位チェックポイント: Transformer ブロック単位でチェックポイントを設定。各ブロックの入力のみを保存し、ブロック内部の活性化値は逆伝播時に再計算。
| 方式 | メモリ削減率 | 計算オーバーヘッド | 実装の容易さ |
|---|---|---|---|
| チェックポイントなし | 0% | 0% | - |
| 全レイヤー均等 | 60〜70% | 30〜35% | 容易 |
| セグメント(ブロック)単位 | 50〜60% | 25〜30% | 中程度 |
| 選択的チェックポイント | 40〜80% | 15〜25% | 要分析 |
選択的チェックポイント: メモリ消費量が大きいレイヤー(Self-Attention の QKV 行列など)のみチェックポイントし、メモリ消費が小さいレイヤー(LayerNorm など)は保持する方式。分析コストがかかるが、最適なメモリ/計算トレードオフを実現します。
PyTorch では torch.utils.checkpoint.checkpoint 関数で簡単に利用できます。
Hugging Face Transformers では model.gradient_checkpointing_enable() を呼ぶだけで有効化されます。DeepSpeed や FSDP などの分散学習フレームワークでも設定一つで有効化可能です。
Activation Checkpointing は単体でも効果的ですが、他の手法と組み合わせることでさらなるメモリ削減が可能です。
これらを全て組み合わせることで、RTX 4090(24GB VRAM)でも 13B パラメータモデルのフルファインチューニングが可能になります。
A1: 推論時は逆伝播が不要なため、活性化値を保持する必要がなく、チェックポインティングは無意味です。推論のメモリ最適化には KV Cache の量子化やページングなど別の手法を使います。
A2: 理論上は順伝播の再計算分だけ遅くなりますが、実際にはメモリ削減によりバッチサイズを大きくできるため、スループット(トークン/秒)はほぼ維持できるケースが多いです。A100 80GB でチェックポイントなし・バッチ4 vs チェックポイントあり・バッチ8 なら後者の方が高速です。
A3: LoRA は学習パラメータを大幅に削減しますが、順伝播の活性化値は全モデルで計算されるためメモリ消費は残ります。7B モデルの LoRA ファインチューニングでもチェックポイントを有効化すると VRAM 使用量が 4〜6GB 削減されます。