Layer-wise Adaptive Moments optimizer for Batch training の略。大バッチ分散学習向けに設計されたオプティマイザで、レイヤーごとにパラメータのノルムと更新量のノルムの比率で学習率を適応的に調整する。Google が 2019 年に提案し、BERT-Large を TPUv3 1,024 基・バッチサイズ 64K で 76 分で学習する記録を達成した。
LAMB(Layer-wise Adaptive Moments optimizer for Batch training)は、2019 年に Google の Yang You らが論文「Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes」で提案したオプティマイザである。通常、バッチサイズを増大させると学習の品質が低下する(large batch problem)が、LAMB はレイヤーごとの適応的学習率スケーリングによりこの問題を解決した。
分散学習では GPU 数を増やしてバッチサイズを拡大し、学習時間を短縮する。しかし、バッチサイズが大きくなると以下の問題が発生する:
LAMB はこれらの問題を、レイヤーごとの信頼比率(trust ratio)で学習率を動的に調整することで克服する。
LAMB は LARS(Layer-wise Adaptive Rate Scaling)のアイデアを Adam に統合した手法:
信頼比率 φ(θ) がキーで、更新量 r がパラメータ θ に対して相対的に大きすぎる場合は抑制し、小さすぎる場合は増幅する。これにより、異なるスケールのレイヤー間で均一な相対的更新量を維持する。
| バッチサイズ | 使用 GPU (TPUv3) | 学習時間 | オプティマイザ | 性能 (SQuAD F1) |
|---|---|---|---|---|
| 256 | 8 | 〜72 時間 | Adam | 90.9 |
| 8K | 256 | 〜5 時間 | Adam | 89.7 (低下) |
| 8K | 256 | 〜4 時間 | LAMB | 90.8 |
| 32K | 1,024 | 〜100 分 | LAMB | 90.6 |
| 64K | 1,024 | 〜76 分 | LAMB | 90.4 |
Adam ではバッチサイズ 8K で性能が 1.2 ポイント低下するが、LAMB ではほぼ維持。バッチサイズ 64K(通常の 256 倍)でも 0.5 ポイントの低下に抑えた。
LAMB の前身である LARS(Layer-wise Adaptive Rate Scaling)は SGD+Momentum ベースで、主に CNN の大バッチ学習に使われていた:
| 特性 | LARS | LAMB |
|---|---|---|
| ベースオプティマイザ | SGD+Momentum | Adam |
| 対象モデル | CNN (ResNet等) | Transformer (BERT等) |
| 適応的学習率 | なし | あり(Adam由来) |
| 信頼比率 | ||
| 最大バッチサイズ | 32K | 64K+ |
LAMB は LARS の信頼比率を Adam の適応的学習率と組み合わせることで、Transformer アーキテクチャでも大バッチ学習を可能にした。
# NVIDIA apex ライブラリの FusedLAMB
from apex.optimizers import FusedLAMB
optimizer = FusedLAMB(
model.parameters(),
lr=6e-3, # 大バッチでは高い学習率
betas=(0.9, 0.999),
weight_decay=0.01,
max_grad_norm=1.0 # 勾配クリッピング
)
# Warmup は全ステップの 1-2.5% が推奨
scheduler = WarmupLinearSchedule(
optimizer,
warmup_steps=total_steps * 0.01,
t_total=total_steps
)
DeepSpeed では deepspeed.ops.lamb.FusedLamb として CUDA 最適化版が提供されている。
LAMB が有効な条件:
限界:
A: GPU/TPU 256 基以上・バッチサイズ 8K 以上の大規模分散学習で真価を発揮する。個人やスタートアップの 8-16 GPU 環境では AdamW で十分であり、LAMB を使う利点は少ない。Google、Meta、NVIDIA など大規模クラスタを持つ組織の事前学習が主な適用先。
A: ほぼ変わらない。ファインチューニングでは通常バッチサイズ 16-256 程度で、LAMB の大バッチ適応機構が活かされない。ファインチューニングでは AdamW(lr=1e-5〜5e-5)が推奨される。
A: 大バッチ学習では勾配推定の分散が小さくなり、各ステップの更新が保守的になりすぎる。LAMB は信頼比率によるレイヤー適応とともに高い学習率(1e-3〜6e-3)を使うことで、保守的な更新を補正し学習速度を維持する。Linear Scaling Rule(バッチ N 倍→学習率 N 倍)を超えた適応的スケーリングが LAMB の本質的な貢献。