Google が開発した機械学習専用コンパイラで、TensorFlow や JAX の計算グラフを TPU・GPU・CPU 向けに最適化コンパイルする Accelerated Linear Algebra の略称。
XLA(Accelerated Linear Algebra)は Google が開発したドメイン固有コンパイラであり、線形代数演算を中心とする機械学習ワークロードを各種ハードウェア向けに最適化コンパイルする。TensorFlow のデフォルトコンパイラバックエンドとして組み込まれ、JAX では唯一の実行エンジンとして機能する。
Google の TPU 上で LLM を学習・推論する場合、XLA は不可欠な存在である。Gemini・PaLM・T5 といった大規模モデルの学習はすべて XLA を通じて TPU クラスタ上で実行されている。
XLA の中間表現である HLO IR は、行列乗算・畳み込み・リダクションなどの高水準演算をノードとする計算グラフである。フレームワーク(TensorFlow, JAX, PyTorch/XLA)からのコードはまず HLO に変換される。
HLO レベルで適用される主要な最適化は以下の通り。
XLA の HLO を標準化した StableHLO は、フレームワーク間のポータビリティを向上させる取り組みである。MLIR(Multi-Level Intermediate Representation)上に構築され、異なるコンパイラバックエンド間での互換性を提供する。
| コンパイルステージ | 処理内容 | 最適化例 |
|---|---|---|
| HLO 生成 | フレームワークからの変換 | 型推論・形状推論 |
| HLO 最適化 | グラフレベル最適化 | 融合・定数畳み込み・CSE |
| レイアウト割り当て |
| メモリ配置決定 |
| TPU MXU 対応タイリング |
| バッファ割り当て | メモリ管理 | Liveness 解析・再利用 |
| コード生成 | デバイス固有コード出力 | PTX / TPU 命令 / LLVM IR |
TPU v4/v5 の MXU は 128x128 の行列乗算を 1 サイクルで実行する。XLA はテンソル演算を 128x128 タイルに自動分割し、MXU の利用率を最大化する。パディング挿入やテンソル転置の自動挿入もコンパイラが担当する。
TPU v4 以降はチップ内に 2 つのコアを持つ Megacore 構成となっている。XLA は計算グラフを自動的に 2 分割し、コア間の通信を最小化する配置を決定する。
TPU Pod 内のチップ間通信は ICI 経由で行われる。XLA の SPMD パーティショナが AllReduce・AllGather などの集団通信を最適なトポロジに配置し、通信レイテンシを削減する。
JAX は XLA をネイティブの実行エンジンとして使用し、以下の機能を提供する。
Python 関数を XLA コンパイル対象としてマークする。初回呼び出し時にトレース → HLO 変換 → コンパイルが行われ、以降はコンパイル済みカーネルが直接実行される。
データ並列・テンソル並列の両方を XLA の SPMD パーティショナ経由で実現する。PartitionSpec でテンソルの分割軸を宣言すると、XLA が自動的に通信プリミティブを挿入する。
自動微分も XLA の HLO レベルで処理される。フォワードパスの HLO からバックワードパスの HLO を自動生成し、両パスを通じた最適化(勾配チェックポインティングの自動挿入など)が適用される。
PyTorch モデルを XLA バックエンド上で実行するためのブリッジライブラリ。torch_xla.core.xla_model 経由でテンソルを XLA デバイスに配置すると、PyTorch の Eager 実行が XLA の Lazy Tensor 方式に切り替わる。
XLA は Google エコシステムの中核コンパイラとして位置付けられ、StableHLO を通じて外部コンパイラとの相互運用が進んでいる。IREE(Intermediate Representation Execution Environment)は StableHLO を入力として受け取り、モバイル・エッジデバイス向けのコード生成を行う。
はい、XLA は NVIDIA GPU(CUDA バックエンド)にも対応しています。ただし、NVIDIA GPU 上では TensorRT や Triton(torch.compile 経由)の方が成熟度が高く、一般的には TPU 上での利用が XLA の主戦場です。JAX を NVIDIA GPU で使う場合は自動的に XLA の CUDA バックエンドが使用されます。
XLA のコンパイルキャッシュ(XLA_FLAGS=--xla_gpu_persistent_compilation_cache_dir=/path)を有効にすると、同じ計算グラフの再コンパイルを回避できます。また、JAX では jax.jit のトレース結果をキャッシュする AOT(Ahead-of-Time)コンパイルモードも利用可能です。
はい、JAX をインストールすれば NVIDIA GPU 上で XLA を利用できます(pip install jax[cuda12])。ただし、XLA の真価は TPU クラスタでの大規模分散学習にあるため、小規模実験には torch.compile + Triton の方が手軽です。Google Colab の無料 TPU ランタイムで XLA + JAX を体験するのも良い選択肢です。