7.

JAX とは?Google 製の関数型数値計算フレームワーク

編集

本稿は JAX (ジャックス) に関する記事です。

この記事の要点
  • JAX は "NumPy 互換 API + 自動微分 + GPU/TPU + JIT" を統合した数値計算フレームワーク
  • Google が開発・Google DeepMind の研究の主軸。Gemini など先端モデルの学習でも採用
  • 関数型のスタイル(純関数 + 変換)が独特。慣れが必要だが性能・並列化が強力
  • jit / grad / vmap / pmap という関数変換で書く
  • JAX 単体は低レベル。実用ではニューラルネット用上位ライブラリ Flax / Equinox を併用
  • TPU を活かせる事実上唯一の主要フレームワーク(PyTorch も対応進行中)

JAX とは?

JAX は、Google が開発する数値計算ライブラリです。「NumPy のような API」「自動微分」「GPU / TPU での実行」「XLA による JIT コンパイル」「並列化変換」を、関数を変換する形で統一的に提供します。

PyTorch / TensorFlow が「フレームワーク」だとすれば、JAX は「NumPy の強化版」に近い思想で設計されており、ニューラルネット以外の数値最適化・物理シミュレーション・確率モデル等にも自然に使えます。深層学習向けの上位ライブラリ (Flax / Equinox / Haiku) を上に積んで使うのが一般的です。

主な特徴

特徴説明
NumPy 互換 APIjax.numpy as jnp がほぼ NumPy と同じ呼び方
自動微分 (grad)任意関数の偏微分・高階微分を関数変換として取得
JIT コンパイル (jit)XLA で関数全体を最適化。GPU/TPU でも高速
ベクトル化 (vmap)ループ不要でバッチ並列化
デバイス並列 (pmap / pjit)複数 GPU / TPU 跨ぎの並列実行
純関数モデル副作用なしの関数を変換していくスタイル
TPU 親和性Google TPU で性能が最も出やすい
関数型エコシステムFlax / Optax / Orbax / Chex 等のモジュール群

4 つの主要変換

変換役割直感
jit(fn)関数を XLA でコンパイル「同じ計算を爆速にする」
grad(fn)勾配を返す関数を生成「微分してくれる」
vmap(fn)関数を自動でバッチ化「for ループを並列化」
pmap(fn) / pjit複数デバイス間で並列実行「複数 GPU/TPU に分散」

最小サンプル: NumPy 風+自動微分+JIT

import jax
import jax.numpy as jnp

# NumPy 風に書ける
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# 勾配を取る関数を生成
grad_f = jax.grad(f)

# JIT コンパイル
fast_grad_f = jax.jit(grad_f)

x = jnp.array([0.0, 1.0, 2.0])
print(fast_grad_f(x))

関数型スタイルの注意点

PyTorch / NumPy 出身者が戸惑うポイント
  • 配列はイミュータブル: x[0] = 1 はできない。x = x.at[0].set(1) を使う
  • 純関数でないと jit / vmap が破綻する (外部状態・Python の制御フロー注意)
  • Python の for 内で副作用を出すと、JIT 時にトレースが効かない
  • ランダム数は明示的な PRNG キー (jax.random.PRNGKey) を渡す
  • シェイプ依存の分岐は jax.lax.cond / jax.lax.fori_loop を使う
  • NaN や形状不一致はコンパイル後に出るので、jax.config.update("jax_disable_jit", True) でデバッグ

深層学習用上位ライブラリ

ライブラリ立ち位置
Flax (Linen / NNX)Google 公式。Gemini 等の社内利用実績。現在は NNX が新世代
EquinoxPyTorch ライクな素直さ。クラスで層を書く
HaikuDeepMind 製。Flax の前身的存在 (現在は保守モード)
Optax最適化アルゴリズム (Adam / AdamW / Lion 等)
Orbaxチェックポイント保存・配信
Chexテスト・アサーションのユーティリティ
NumPyro確率プログラミング (Pyro の JAX 版)
Diffrax微分方程式ソルバ
jaxtypingテンソルの形状・型を注釈

PyTorch / TensorFlow との立ち位置

観点 JAX PyTorch TensorFlow
パラダイム関数型命令型 (eager)命令型 (eager) + Keras
TPU 親和性最強対応進行中従来からの定番
研究での採用Google 系・確率プログラミング圧倒的多数減少傾向
本番デプロイ少ない (JAX2TF 経由)TorchServe / ONNXTF Serving / TFLite / TF.js
学習コスト高め (関数型 + 純関数)低い
科学計算・物理得意

使うのが向くシーン

  • TPU を活用したい(Google Cloud TPU、Colab TPU)
  • 大規模分散学習を関数変換で表現したい
  • 確率プログラミング・科学計算・物理シミュレーション
  • 研究プロトで論文の実装を NumPy 風に純粋に書きたい
  • Google DeepMind / Google Brain 系の論文を再現したい

インストール

# CPU 版
pip install jax

# NVIDIA GPU 版 (CUDA 12)
pip install -U "jax[cuda12]"

# TPU (Google Cloud / Colab TPU)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# よく一緒に入れるもの
pip install flax optax

運用上のヒント

Tips
  • 最初は jit 無しで書いて挙動を確認 → 安定したら jit を被せる
  • vmap を使うクセ: バッチ次元のループを書きたくなったらまず vmap
  • 勾配が大きい / NaN が出る場合は jax.debug.nan_checkjax.debug.print
  • 大規模モデルは FSDP 相当pjit / shard_map でテンソル分割
  • チェックポイントは Orbax。学習中断・再開や非同期保存
  • PyTorch コードを移植する場合、まず NumPy 部分を jnp に置換 → jit 化のステップを踏む

注意点

よくある落とし穴
  • 初回 JIT コンパイルが遅い: 起動時のキャッシュコストを許容できる用途を選ぶ
  • 動的な形状でループするとコンパイルし直しになる。形状を固定化する設計
  • イミュータブル配列を理解しないと初学者が詰まる
  • 本番デプロイのエコシステムが PyTorch / TF より薄い
  • GPU でのデバイス・メモリ管理 (XLA_PYTHON_CLIENT_PREALLOCATE 等) が独特
  • Windows ネイティブの GPU サポートは限定的。Linux / WSL2 を推奨
  • JAX とそのエコシステム (Flax / Optax) はバージョン整合が崩れやすい。lock ファイルで固定

関連

編集
Post Share
子ページ

子ページはありません

同階層のページ
  1. PyTorch
  2. TensorFlow(テンソルフロー)
  3. scikit-learn
  4. Hugging Face Transformers
  5. LangChain
  6. LlamaIndex
  7. JAX
  8. ONNX Runtime

最近更新/作成されたページ