大規模モデルのパラメータを複数のデバイス(GPU/TPU)に分割して配置し、単一デバイスのメモリ容量を超えるモデルの学習・推論を可能にする分散処理技術。
Model Sharding(モデルシャーディング)は、大規模言語モデル(LLM)のパラメータを複数のアクセラレータに分割配置する技術である。GPT-4 や Llama 3 405B のように数千億パラメータを持つモデルは、単一 GPU のメモリ(80GB〜192GB)では収まりきらないため、シャーディングが不可欠となる。
現代の LLM は急速にスケールしている。70B パラメータモデルを FP16 で保持するだけで約 140GB の VRAM が必要であり、これに勾配(gradient)やオプティマイザ状態(optimizer state)を加えると、単一 GPU では到底収まらない。Model Sharding はこの物理的制約を突破し、複数デバイスのメモリを仮想的に統合する。
| モデル規模 | FP16 パラメータメモリ | 学習時総メモリ(目安) | 必要 GPU 数(H100 80GB) |
|---|---|---|---|
| 7B | 14 GB | 56 GB | 1 |
| 70B | 140 GB | 560 GB | 8 |
| 405B | 810 GB | 3,240 GB | 48+ |
Model Sharding にはいくつかの代表的なアプローチがある。
個々の演算テンソル(重み行列)を列方向または行方向に分割し、複数 GPU で並列に計算する。Megatron-LM が提唱した手法で、Transformer の Self-Attention 層と MLP 層に適用される。ノード内の高速 NVLink 接続を前提とするため、通常 8 GPU 以内で使用される。
モデルのレイヤーを段階的に異なる GPU に配置し、マイクロバッチを流す方式。ノード間通信量がテンソル並列より少ないため、マルチノード構成に適する。ただし、パイプラインのバブル(空き時間)が生じるため、1F1B スケジューリングや Interleaved Pipeline で最適化する。
モデル全体を各 GPU に複製し、データを分割して並列学習する。DeepSpeed ZeRO や PyTorch FSDP では、パラメータ・勾配・オプティマイザ状態をシャーディングすることで、メモリ効率を大幅に改善する。
テンソル並列 × パイプライン並列 × データ並列を組み合わせた最も強力な構成。Megatron-DeepSpeed や NeMo Framework が代表的な実装で、数千 GPU クラスタでのプレトレーニングに使用される。
| フレームワーク | テンソル並列 | パイプライン並列 | ZeRO/FSDP | 主な用途 |
|---|---|---|---|---|
| Megatron-LM | ○ | ○ | ○ | プレトレーニング |
| DeepSpeed | △(推論時) | ○ | ○(ZeRO-1/2/3) | 学習・推論 |
| PyTorch FSDP | △(DTensor) | △ | ○ | 学習 |
| NeMo Framework | ○ | ○ | ○ | エンドツーエンド |
| vLLM | ○ | ○ |
学習時だけでなく、推論時にも Model Sharding は重要である。vLLM や TensorRT-LLM はテンソル並列を用いて複数 GPU に推論負荷を分散し、レイテンシを削減する。70B モデルを 4 GPU でテンソル並列すると、単一 GPU 比で約 3.5 倍のスループットが得られる。
シャーディングでは GPU 間通信がボトルネックになる。AllReduce、AllGather、ReduceScatter などの集団通信(Collective Communication)を NCCL が最適化する。NVLink(900 GB/s)とInfiniBand(400 Gbps)の帯域差がノード内外の並列戦略選択に直結する。
A1: 広義には同じ概念を指す。Model Parallelism はモデルを複数デバイスに分割する総称で、Model Sharding はその中でも特にパラメータの分割配置に焦点を当てた用語である。実務上はほぼ同義で使われる。
A2: 推論のみなら単一 GPU で十分な場合が多い。ただし学習時はオプティマイザ状態でメモリが膨らむため、7B でもフルパラメータファインチューニングには FSDP や ZeRO が有効である。QLoRA のようなメモリ効率の良い手法を使えば単一 GPU でも可能。
A3: テンソル並列度(TP)とパイプライン並列度(PP)のバランスが最重要。TP はノード内 GPU 数(通常 4 または 8)に合わせ、PP はノード数に合わせるのが基本。DP(データ並列度)は残りの GPU で自動決定される。
| - |
| 推論特化 |