Stanford大学のTri Daoらが提案したAttention計算のIO最適化アルゴリズム。GPUのHBM(高帯域メモリ)とSRAM(オンチップメモリ)間のデータ転送を最小化するタイリング手法により、標準的なAttentionと数学的に等価な結果をメモリ使用量O(n)・最大3倍の高速化で実現する。
FlashAttention は、2022年にStanford大学の Tri Dao らが発表したAttention計算の高速化手法です。Transformerの標準的なSelf-Attention実装では、n×n のAttention行列全体をGPUのHBM(High Bandwidth Memory)に書き込む必要があり、これがメモリと速度の両方のボトルネックになっていました。FlashAttentionはこの行列を明示的に実体化せず、タイリング(ブロック分割)とオンライン正規化によってSRAM内で計算を完結させます。
| メモリ種別 | 容量(A100) | 帯域幅 | 用途 |
|---|---|---|---|
| HBM(VRAM) | 80 GB | 2 TB/s | モデル重み、KVキャッシュ |
| L2キャッシュ | 40 MB | ~5 TB/s | 中間バッファ |
| SRAM(共有メモリ) | 192 KB/SM | ~19 TB/s | レジスタ直近の高速メモリ |
標準Attention実装は「QK^Tの全n×n行列をHBMに書き込み→Softmax→V乗算」と3回のHBMアクセスが必要です。FlashAttentionはこれを1パスに統合し、ブロック単位でSRAM内に保持したまま計算を完結させることで、HBMアクセス回数を劇的に削減します。
Attention行列をB_q × B_k サイズのブロックに分割し、1ブロックずつSRAMに読み込んで処理します。各ブロックの計算結果はオンライン(逐次的に)正規化することで、全体のSoftmaxと数学的に同一の結果を得られます。
| バージョン |
|---|
| 発表年 |
|---|
| 主要改善 |
|---|
| 速度(vs標準) |
|---|
| 対応GPU |
|---|
| FlashAttention v1 | 2022 | タイリング+リマテリアライゼーション | 2-4x | A100, H100 |
| FlashAttention-2 | 2023 | ワープ間並列化、非因果マスク対応 | 5-9x | A100, H100, RTX 4090 |
| FlashAttention-3 | 2024 | FP8対応、非同期ワープスケジューリング | 1.5-2x vs FA2 | H100, H200 |
シーケンス長8,192、ヘッド数32、d_k=128の場合:
| 実装 | Attention行列メモリ | 合計メモリ | 備考 |
|---|---|---|---|
| 標準(PyTorch naive) | 2 GB/ヘッド | 64 GB | n²のAttention行列を実体化 |
| FlashAttention | ~数MB | ~数MB | 行列を実体化しない |
FlashAttentionはAttention行列をメモリに保持しないため、シーケンス長に対してO(n)のメモリ使用量で済みます。これにより、同一GPUメモリで扱えるコンテキスト長が大幅に拡大しました。
| フレームワーク | 統合方法 | 備考 |
|---|---|---|
| PyTorch 2.0+ | torch.nn.functional.scaled_dot_product_attention | 自動的にFlashAttentionカーネルを選択 |
| Hugging Face Transformers | model.config.attn_implementation = "flash_attention_2" | 明示的に有効化 |
| vLLM | デフォルトで有効 | PagedAttentionと併用 |
| TensorRT-LLM | 自動最適化 | H100向けFA3も統合 |
| llama.cpp | Metal/CUDA対応のFA実装内蔵 | ローカルLLM推論 |
A1: 数学的には同一の計算を行うため、理論上は出力が変わりません。ただし浮動小数点演算の順序が異なるため、FP16精度では末尾数桁に微小な差が生じる場合がありますが、モデルの品質に影響するレベルではありません。
A2: FlashAttention-2はAmpere世代(RTX 3090含む)で動作しますが、一部の最適化(FP8等)はHopper世代(H100以降)でのみ有効です。RTX 3090でも標準Attention比2〜4倍程度の高速化が期待できます。
A3: FlashAttentionはAttention計算自体のIO最適化(タイリング、SRAM活用)、PagedAttentionはKVキャッシュのメモリ管理最適化(仮想メモリ的なページング)です。両者は異なるレイヤーの最適化であり、vLLMなどでは併用されています。