Retentive Networkの略で、Microsoftが2023年に提案したTransformer代替アーキテクチャ。並列学習・リカレント推論・チャンク並列推論の3つの計算モードを切り替え可能で、学習はTransformer同等の並列性、推論はRNN同等の効率性を実現する。
PC構成ビルダーで最適なパーツを選択
RetNet(Retentive Network)は、Microsoftの研究チームが2023年7月に発表した「Retentive Network: A Successor to Transformer for Large Language Models」論文で提案されたアーキテクチャである。名称の「Retentive」は、情報を「保持」する仕組みであるRetention(保持)メカニズムに由来する。TransformerのAttentionに代わるRetentionメカニズムにより、学習・推論・チャンク処理の3つの等価な計算モードを提供する。
RetNetの核心であるRetentionは、以下のように定義される:
Retention(X) = (QK^T ⊙ D) V
ここでDは因果的な減衰マスク:D_{nm} = γ^{n-m}(n ≥ m)、0(n < m)。γは減衰率(0 < γ < 1)で、ヘッドごとに異なる値が設定される。
この定式化の重要な点は、同じ計算を3つの異なるモードで等価に実行できることである。
1. 並列モード(学習時):
Retention_parallel = (QK^T ⊙ D) V [O(n²)だがGPUで高速]
Transformerと同様に全トークンを一括処理。GPUのテンソルコアを最大活用し、学習スループットを最大化する。
2. リカレントモード(推論時):
s_n = γ s_{n-1} + K_n^T V_n [状態更新]
Retention_n = Q_n s_n [出力計算]
RNNのように状態ベクトルsを逐次更新。O(1)の計算量とメモリで各トークンを生成でき、自己回帰推論に最適。
3. チャンクワイズモード(長文推論時):
チャンク内: 並列モードで計算
チャンク間: リカレントモードで状態伝播
入力をチャンク(例: 512トークン)に分割し、チャンク内は並列、チャンク間は再帰。長文の推論効率と品質を両立する。
| モデル | パラメータ | PPL (WikiText-103) | 学習速度 | 推論メモリ | 推論速度 |
|---|---|---|---|---|---|
| Transformer | 1.3B | 15.2 | 1.0x | O(n) KV | 1.0x |
| RetNet | 1.3B | 15.1 | 1.0x | O(1) 状態 | 3.5x |
| Transformer | 2.7B | 13.8 | 1.0x | O(n) KV | 1.0x |
| RetNet | 2.7B | 13.7 | 1.0x | O(1) 状態 | 3.4x |
| Transformer | 6.7B | 12.5 | 1.0x | O(n) KV | 1.0x |
| RetNet | 6.7B | 12.4 | 0.97x | O(1) 状態 | 3.2x |
RetNetではマルチヘッドAttentionに相当するMulti-Scale Retentionを採用する:
| 特性 | Transformer | RetNet |
|---|---|---|
| 学習並列性 | 完全並列(O(n²)) | 完全並列(O(n²)) |
| 推論計算量 | O(n)(トークンあたり) | O(1)(トークンあたり) |
| 推論メモリ | O(n)(KVキャッシュ) | O(1)(状態ベクトル) |
| 位置エンコーディング | RoPE/ALiBi | 指数減衰(γ) |
| In-context learning | 優秀 | 良好 |
| 長文処理 | KVキャッシュが制約 | チャンクワイズで効率的 |
| 学習安定性 | 確立済み | GroupNormが重要 |
RetNetの実装にはいくつかの工夫が必要である:
RetNetの発表後、以下の派生研究が進んでいる:
Q1: RetNetはなぜTransformerより速いのか? A: 推論時にリカレントモードを使うことで、各トークンの生成がO(1)の計算量・O(1)のメモリで完了するため。Transformerは過去全トークンのKVキャッシュに対するAttention計算が必要で、コンテキストが長くなるほど遅くなる。RetNetは状態ベクトルのサイズが一定のため、生成速度が系列長に依存しない。
Q2: RetNetの3つのモードはどう使い分けるのか? A: 学習時は並列モード(GPUの並列性を最大活用)、自己回帰推論時はリカレントモード(最小メモリ・最速生成)、長文の入力処理時はチャンクワイズモード(並列性と効率のバランス)を使う。これらは数学的に等価な結果を出力するため、用途に応じて自由に切り替えられる。
Q3: RetNetの限界は何か? A: 主に3点。(1)指数減衰により、非常に遠いトークンの情報が失われやすい(100K+トークンでの性能低下)。(2)Transformerほどのスケーリング実績がなく、最大13Bパラメータにとどまる。(3)FlashAttentionのような成熟した高速化ライブラリが不足している。