FlashAttentionの技術をデコード(推論)フェーズに特化して最適化した手法。長いコンテキストでの自己回帰生成において、Attentionの並列度を高めてGPU利用率を改善する。
Flash Decoding は、Tri Dao らが 2023 年に提案した、LLM のデコード(トークン生成)フェーズにおける Attention 計算を高速化する手法です。FlashAttention がプレフィル(入力処理)フェーズの高速化に大きく貢献した一方、デコードフェーズでは1トークンずつ生成するため Query が1行しかなく、GPU の演算ユニットを十分に活用できない問題がありました。Flash Decoding はこの問題をシーケンス長方向の並列化で解決します。
自己回帰生成では、1ステップで生成するトークンは1つです。つまり Query 行列は [1, head_dim] という極めて小さなテンソルになります。一方、KVキャッシュは [seq_len, head_dim] と長大です。
従来の FlashAttention はバッチサイズとヘッド数の次元で並列化しますが、デコード時はバッチサイズ1のことも多く、ヘッド数(例: 32〜128)だけでは GPU の数千コアを埋め切れません。結果として GPU utilization が 10〜30% 程度に留まることがあります。
Flash Decoding は KVキャッシュのシーケンス長次元を分割し、各分割を異なる GPU スレッドブロックで並列に処理します。
KVキャッシュを k 個のチャンクに分割し、各チャンクに対して独立に Attention を計算します。各チャンクは局所的な softmax 正規化係数(log-sum-exp)と部分出力を生成します。
各チャンクの部分結果を、log-sum-exp を用いてグローバルに正しい softmax で統合します。Online Softmax の手法を使うことで、数値的に安定した結合が可能です。
| 項目 | 通常Attention | FlashAttention | Flash Decoding |
|---|---|---|---|
| プレフィル速度 | 基準 | 2〜4倍高速 | FlashAttentionと同等 |
| デコード速度(短文) | 基準 | 微改善 | 1.5〜2倍 |
| デコード速度(長文) | 基準 | 微改善 | 最大8倍 |
| GPU utilization | 10〜30% | 20〜40% | 50〜80% |
Flash Decoding++ は Flash Decoding をさらに改良した手法です。リダクションフェーズにおける同期オーバーヘッドを削減し、部分 softmax の計算を非同期化することで追加の高速化を実現します。
コンテキスト長が長いほど Flash Decoding の効果は大きくなります。128Kトークンのコンテキストでは従来手法比で最大8倍の高速化が報告されています。これは GPT-4 Turbo や Claude の長いコンテキストウィンドウを活用するアプリケーションで特に重要です。
A1: 同じ研究グループ(Tri Dao ら)の成果ですが、最適化対象が異なります。FlashAttention はプレフィル(入力処理)フェーズのメモリ帯域幅を最適化し、Flash Decoding はデコード(生成)フェーズの並列度を最適化します。多くの推論エンジンでは両方を組み合わせて使用します。
A2: CUDA対応のNVIDIA GPUが主なターゲットです。特にH100やA100などのデータセンターGPUで効果が大きいですが、RTX 4090等のコンシューマGPUでも動作します。AMD ROCm向けの移植も進んでいます。
A3: バッチサイズが大きいとヘッド数×バッチサイズの並列度で GPU が埋まるため、Flash Decoding の追加効果は小さくなります。主に小バッチ・長コンテキストの条件で劇的な改善が見られます。