PyTorch公式の完全シャード化データ並列(Fully Sharded Data Parallel)フレームワーク。DeepSpeed ZeROと同様にパラメータ・勾配・オプティマイザ状態をGPU間で分割し、大規模モデルの学習を効率化する。PyTorch 2.0以降で安定版として提供される。
PyTorch FSDP(Fully Sharded Data Parallel)は、PyTorch公式が提供する分散学習フレームワークである。Facebookが開発したFairScale FSDPをPTorchコアに統合したもので、DeepSpeed ZeRO Stage 3と同様の完全シャード化データ並列を実現する。PyTorch 2.0以降で安定版(Stable)となり、Meta社のLlama 2/3の学習にも使用された実績を持つ。
FSDPの基本原理は、モデルパラメータをGPU間で分割(シャード)し、Forward/Backward計算時に必要なパラメータをAllGatherで一時収集、計算後に解放するというものである。これにより、各GPUのメモリ使用量をGPU数Nに比例して削減できる。
PyTorch 2.4以降、FSDP2(torch.distributed.fsdp2)が導入され、内部実装が大幅に刷新された。
| 機能 | FSDP1 | FSDP2 |
|---|---|---|
| シャーディング単位 | FlatParameter(結合テンソル) | DTensor(個別パラメータ) |
| 混合精度 | 制限あり(FlatParameter単位) | 柔軟(パラメータ単位) |
| テンソル並列統合 | 非公式 | 公式サポート(TP+FSDP) |
| チェックポイント | 特殊形式 | 標準PyTorchフォーマット |
| メモリ効率 | 良好 | より高い(パディング不要) |
| 通信重畳 | 手動設定 | 自動最適化 |
| API | Module単位のラップ |
| torch.compile互換 |
FSDP1ではモデルの複数パラメータを1つのFlatParameterに結合してシャーディングしていたが、FSDP2ではDTensor抽象化により個別パラメータレベルでシャーディングする。これにより、パディングの無駄が排除され、チェックポイントの互換性も向上した。
FSDPは3つのシャーディング戦略を提供する。
| 戦略 | メモリ削減 | 通信量 | 用途 |
|---|---|---|---|
| FULL_SHARD | 最大(N分割) | AllGather + ReduceScatter | 大規模モデル(デフォルト) |
| SHARD_GRAD_OP | 中(勾配+オプティマイザ) | ReduceScatter | 中規模モデル |
| NO_SHARD | なし(DDP相当) | AllReduce | 小規模モデル・デバッグ |
| HYBRID_SHARD | ノード内FULL + ノード間DDP | ノード内AllGather + ノード間AllReduce | マルチノード最適化 |
HYBRID_SHARDは特に重要で、ノード内の高速NVLink通信でFULL_SHARDを行い、ノード間はDDP的なAllReduceで勾配同期する。これにより、ノード間通信量を大幅に削減しつつ、ノード内メモリ効率を最大化できる。Meta社のLlama 3学習でもこの戦略が採用された。
PyTorch FSDPの基本的な使用パターンを示す。
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
# Mixed Precision設定
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# FSDPラップ
model = FSDP(
model,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
mixed_precision=mp_policy,
auto_wrap_policy=size_based_auto_wrap_policy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
)
auto_wrap_policyにより、一定サイズ以上のモジュール(通常はTransformerBlock単位)を自動的にFSDPユニットとしてラップする。ラップ粒度はメモリ効率と通信効率のトレードオフであり、細かすぎると通信回数が増加し、粗すぎるとメモリピークが上がる。
FSDPとDeepSpeed ZeROは同じ原理(パラメータシャーディング)に基づくが、実装と機能に違いがある。
| 比較項目 | FSDP | DeepSpeed ZeRO |
|---|---|---|
| 段階的最適化 | FULL_SHARD固定 | Stage 1/2/3選択可 |
| CPUオフロード | 基本対応 | 高度な最適化(Infinity) |
| NVMeオフロード | 未対応 | ZeRO-Infinity対応 |
| PyTorch統合 | ネイティブ | 外部ライブラリ |
| torch.compile | 対応(FSDP2) | 制限あり |
| Megatron統合 | FSDP2で公式対応 | Megatron-DeepSpeed |
| HuggingFace統合 | Accelerate経由 | Trainer直接対応 |
| MoE対応 | 限定的 | DeepSpeed-MoE |
PyTorchエコシステムとの統合ではFSDPが優位で、torch.compileによるカーネル融合やtorch.distributed.checkpointとの互換性が高い。一方、CPUオフロードの最適化やMoE対応ではDeepSpeedが先行している。
FSDPと組み合わせる重要なメモリ最適化技術として、活性値チェックポイント(Activation Checkpointing)がある。Forward時の中間活性値を保存せず、Backward時に再計算する手法で、メモリ使用量を大幅に削減できる。
Transformerモデルでは、各ブロックの活性値がシーケンス長×バッチサイズ×隠れ次元に比例するため、32層モデルでは活性値だけで数十GBに達する。チェックポイントをTransformerブロック単位で適用することで、活性値メモリを1/32に削減できる(計算量は約33%増加)。
モデルサイズが単一GPUメモリの60%以下に収まる場合はDDPが高速である。それ以上のモデル、またはバッチサイズを大きくしたい場合はFSDPを推奨する。7Bモデルでも、BF16学習でバッチサイズを増やしたい場合はFSDPのHYBRID_SHARDが有効である。
PyTorch 2.5以降の新規プロジェクトではFSDP2を推奨する。既存のFSDP1コードは引き続き動作するが、torch.compile統合やDTensorベースのチェックポイントなど、FSDP2固有の機能が増えている。HuggingFace Accelerateも FSDP2を段階的にサポート中である。
パラメータ数P、GPU数N、BF16学習の場合: パラメータメモリ ≈ 2P/N バイト、勾配メモリ ≈ 2P/N バイト、オプティマイザメモリ ≈ 12P/N バイト(Adam)。これに活性値メモリ(シーケンス長・バッチサイズ依存)を加算する。活性値チェックポイント適用時はレイヤー数で除算する。