OpenAI が開発した GPU カーネル記述用の Python ベース DSL であり、CUDA の低レベル知識なしに高性能な並列演算カーネルを記述できるコンパイラ言語。
Triton は OpenAI が開発したオープンソースの GPU プログラミング言語およびコンパイラである。Python に似た高水準の構文で GPU カーネルを記述でき、コンパイラが自動的にメモリアクセスの最適化・ループタイリング・命令スケジューリングを行う。CUDA C++ では数百行必要な最適化カーネルを、Triton では数十行で同等以上の性能で実装できる。
PyTorch 2.0 以降、TorchInductor バックエンドが Triton をデフォルトのコード生成ターゲットとして採用しており、torch.compile 経由で自動的に Triton カーネルが生成・実行される。
Triton の @triton.jit デコレータを付与した Python 関数がカーネル定義となる。テンソル演算をブロック(タイル)単位で記述し、プログラマは論理的なブロックインデックスのみを意識する。
Python AST から Triton 独自の IR に変換される。この IR 上で以下の最適化パスが適用される。
Triton IR から LLVM IR を経由して PTX(NVIDIA GPU アセンブリ)を生成する。Hopper 世代では TTGIR(Triton GPU IR)から直接 NVPTX を生成するパスも追加されている。
Triton で行列乗算を記述する場合、プログラマは出力行列のブロック座標を tl.program_id で取得し、入力行列の対応するタイルを tl.load で読み込み、tl.dot でブロック行列積を計算する。
Triton 版 FlashAttention は以下の手順で動作する。
この実装により、標準的な Attention 演算と比較してメモリ使用量が O(N^2) から O(N) に削減される。
| 特性 | Triton | CUDA C++ | CuPy |
|---|
| 記述量 | 少ない(数十行) | 多い(数百行) | 少ない |
| 性能上限 | cuBLAS の 90-100% | 理論値に近い | cuBLAS の 60-80% |
| 自動最適化 | タイリング・合体 | 手動 | なし |
| デバッグ | Python ベース | printf/NSight | Python ベース |
| 対応 HW | NVIDIA GPU | NVIDIA GPU | NVIDIA GPU |
| 動的シェイプ | JIT 再コンパイル | テンプレート | 対応 |
torch.compile(model) を呼ぶと TorchInductor が FX グラフを解析し、融合可能な演算パターンを Triton カーネルに変換する。プログラマが明示的に Triton コードを書く必要はない。
torch.library.custom_op を使って Triton カーネルを PyTorch のオペレータとして登録できる。これにより torch.compile のグラフ内でカスタムカーネルが呼び出され、前後の演算との融合最適化も適用される。
Triton カーネルに torch.autograd.Function のフォワード/バックワード定義を与えることで、学習ループでも利用可能になる。FlashAttention の学習時バックワードパスも Triton で実装されている。
Triton カーネルの性能はブロックサイズ(BLOCK_M, BLOCK_N, BLOCK_K)に強く依存する。一般的な指針は以下の通り。
triton.autotune デコレータでブロックサイズやパイプラインステージ数の候補を列挙し、実行時に最速の組み合わせを自動選択できる。
いいえ、Triton は CUDA の上位レイヤーに位置する抽象化であり、Warp レベルの細粒度制御が必要な場合(例: Warp Shuffle 命令の活用、非同期コピーの手動制御)は CUDA C++ が依然として必要です。ただし、一般的な GEMM・Attention・ElementWise 演算では Triton で十分な性能が得られます。
TRITON_INTERPRET=1 環境変数を設定すると、Triton カーネルが Python インタプリタ上で逐次実行されるため、通常の Python デバッガ(pdb, breakpoint)が使えます。性能プロファイリングには NVIDIA Nsight Compute が利用可能です。
NVIDIA GPU(Compute Capability 7.0 以上、Volta 世代以降)と CUDA Toolkit 11.4 以上が必要です。PyTorch 2.0 以降をインストールすれば Triton は自動的に含まれます。RTX 3060 以上であれば、torch.compile 経由で Triton の恩恵を受けられます。