分散学習の最も基本的な手法。モデル全体を各GPUに複製し、学習データのミニバッチを各GPUに分配して並列に前方伝播・後方伝播を実行する。各GPUで計算した勾配をAllReduceで集約・平均してパラメータを同期更新する。
PC構成ビルダーで最適なパーツを選択
データ並列(Data Parallelism, DP)は、分散学習の中で最も基本的かつ広く使われる手法である。モデルの完全なコピーを各GPU(ワーカー)に配置し、学習データセットをワーカー間で分割して並列に学習を行う。各ワーカーが計算した勾配をAllReduceで集約・平均し、全ワーカーのモデルパラメータを同期的に更新する。
PyTorchの標準データ並列実装であるDDPは以下の最適化を行う。
| 機能 | 説明 | デフォルト値 |
|---|---|---|
| 勾配バケット化 | 小さな勾配テンソルをバケットにまとめてAllReduce | 25MB |
| 計算-通信オーバーラップ | backward計算と並行して勾配をAllReduce | 有効 |
| 未使用パラメータ検出 | forward未使用のパラメータの勾配をスキップ | False |
| 勾配累積 | N回のforward/backwardを勾配累積してからAllReduce | 1 |
| バックエンド | 通信バックエンド選択 | nccl (GPU) |
# PyTorch DDPの基本的な使い方(概念例)
# model = DistributedDataParallel(model, device_ids=[local_rank])
# optimizer.zero_grad()
# loss = model(input).sum()
# loss.backward() # AllReduceは自動実行
# optimizer.step()
MicrosoftのDeepSpeedが提唱したZeROは、データ並列の冗長性を排除してメモリ効率を大幅に改善する。
| ステージ | 分割対象 | メモリ削減 | 通信量 |
|---|---|---|---|
| ZeRO-1 | オプティマイザ状態 | 4倍 | DDP同等 |
| ZeRO-2 | + 勾配 | 8倍 | DDP同等 |
| ZeRO-3 | + パラメータ | Nリニア | 1.5倍増 |
| ZeRO-Infinity | + CPUオフロード | ほぼ無制限 | 大幅増 |
ZeRO-3では7Bモデル(FP16で14GB)の学習が、8GPUならGPUあたり約1.75GB+αのパラメータメモリで済む。オプティマイザ状態(Adamで約56GB)も8分割で7GB/GPUとなる。
MetaのFSDPはZeROの概念をPyTorchネイティブに実装したものである。
FULL_SHARD(ZeRO-3相当)、SHARD_GRAD_OP(ZeRO-2相当)、NO_SHARD(DDP相当)checkpoint_wrapperで再計算を自動適用transformer_auto_wrap_policyでTransformerブロック単位に自動シャーディングtorch.compileとの互換性が改善| GPU数 | グローバルバッチサイズ | 学習率 | スケーリング効率(理想1.0) |
|---|---|---|---|
| 1 | 32 | 1e-4 | 1.0(基準) |
| 8 | 256 | 8e-4 | 0.95-0.98 |
| 64 | 2,048 | 6.4e-3 | 0.90-0.95 |
| 256 | 8,192 | 2.56e-2 | 0.85-0.92 |
| 1,024 | 32,768 | ウォームアップ必須 | 0.75-0.85 |
大規模バッチでは学習率のスケーリングが重要。線形スケーリングルール(GPU数に比例して学習率を上げる)やLARS/LAMBオプティマイザを使用する。
Q1: データ並列とDDP(DistributedDataParallel)の違いは?
A: データ並列は概念・手法の名称。DDPはPyTorchにおけるデータ並列の実装。旧式のDataParallel(DP)はGILボトルネックで非推奨。DDPはプロセスベースでGPU間のAllReduceを効率的に行う。
Q2: ZeRO-3とFSDPはどちらを使うべきですか? A: PyTorchエコシステムに統一したいならFSDPが推奨。DeepSpeed固有機能(ZeRO-Infinity、1-bit Adam、勾配圧縮)が必要ならZeRO-3。性能はほぼ同等だが、FSDPはPyTorch公式でありtorch.compileとの統合が優れている。
Q3: データ並列だけで大規模モデルを学習できますか? A: ZeRO-3/FSDPを使えば数十億パラメータまでは可能。70Bモデルを64台のH100(ZeRO-3)で学習した実績がある。ただし405B以上ではテンソル並列・パイプライン並列との組み合わせ(3D並列)が必須。