17.

Julia Gen.jl 入門完全ガイド — 確率プログラミングフレームワーク(MIT 開発)

編集
この記事の要点
  • Gen.jl は MIT 開発の確率的プログラミング(PPL)フレームワーク
  • @gen マクロで生成モデルを関数として定義、内部で @trace によりランダム変数を記録
  • 推論アルゴリズム: Importance Sampling / Metropolis-Hastings / HMC / Particle Filter
  • 応用例: コンピュータビジョン(3D 物体推定)、ロボティクス、自然言語、政策推論
  • Stan / PyMC との違い: 推論プログラムも自分で書ける柔軟性、研究用途で強い

Gen.jl とは

Gen は MIT Probabilistic Computing Project が開発する Julia 言語の確率プログラミング言語(PPL)です。Stan や PyMC と同じく「確率モデルを書いて推論する」フレームワークですが、Gen は推論プログラム自体も自由に記述できるのが特徴で、研究用途で広く使われています。

確率プログラミングの基本概念

確率プログラミングは「データから未知の変数を推論する」プログラムを確率モデルとして表現します:

  • 生成モデル(Generative Model): データがどう生まれるかを記述
  • 事後推論(Posterior Inference): 観測データから隠れ変数の分布を求める
  • マルコフ連鎖モンテカルロ(MCMC): 事後分布からサンプリング

インストール

# Julia の REPL で
using Pkg
Pkg.add("Gen")

using Gen

最初の例: 線形回帰

using Gen

# 線形回帰の生成モデル
@gen function line_model(xs::Vector{Float64})
    # 傾きと切片を正規分布から
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)

    # 各 x に対して y を観測
    n = length(xs)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, i))
    end
end

xs = [-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
ys = [-8.4, -7.1, -4.8, -3.7, -2.2, -0.4, 1.6, 3.4, 5.3, 7.1, 9.2]

# 観測値を choicemap に
observations = choicemap()
for (i, y) in enumerate(ys)
    observations[(:y, i)] = y
end

# Importance Sampling で推論
(trace, _) = Gen.importance_resampling(line_model, (xs,), observations, 1000)

# 推定値
println("slope = ", trace[:slope])
println("intercept = ", trace[:intercept])

@gen マクロの仕組み

@gen はマクロで、関数を生成関数(Generative Function)に変換します:

@gen function my_model()
    # @trace でランダム変数を記録
    x = @trace(normal(0, 1), :x)
    y = @trace(bernoulli(0.5), :y)

    # 通常の Julia コードも書ける
    z = x + (y ? 1.0 : 0.0)

    # サブモデル呼び出しも可能
    w = @trace(another_model(), :sub)

    return z + w
end

@trace(分布, アドレス): ランダム変数をアドレス付きで実行履歴(trace)に記録します。

推論アルゴリズム一覧

アルゴリズム関数用途
Importance Samplingimportance_samplingシンプルな事後推論
Importance Resamplingimportance_resampling少数サンプル
Metropolis-Hastingsmetropolis_hastingsMCMC の基本
HMChmc連続変数、勾配利用
MAP 最適化map_optimize最大事後確率推定
Particle Filterparticle_filter時系列・状態空間
変分推論black_box_vi近似ベイズ

Metropolis-Hastings の例

function block_resimulation_update(tr)
    # slope を再サンプル
    (tr, _) = mh(tr, select(:slope))
    # intercept を再サンプル
    (tr, _) = mh(tr, select(:intercept))
    return tr
end

# 初期 trace を取得
(trace, _) = generate(line_model, (xs,), observations)

# MCMC を 1000 回
for i in 1:1000
    trace = block_resimulation_update(trace)
end

println("Final slope: ", trace[:slope])
println("Final intercept: ", trace[:intercept])

応用例: 3D 物体姿勢推定

Gen は逆グラフィックス(Inverse Graphics)研究で多く使われます:

@gen function pose_model()
    # 物体の位置・回転をサンプル
    x = @trace(uniform(-5, 5), :x)
    y = @trace(uniform(-5, 5), :y)
    theta = @trace(uniform(0, 2π), :theta)

    # レンダリング(シミュレータ呼び出し)
    rendered_image = render_scene(x, y, theta)

    # 観測画像との差を尤度として
    for i in 1:length(rendered_image)
        @trace(normal(rendered_image[i], 0.1), (:pixel, i))
    end
end

他 PPL との比較

フレームワーク言語推論得意分野
Gen.jlJuliaカスタム推論可研究、逆グラフィックス
Stan独自 DSLHMC / NUTS 固定統計モデル全般
PyMCPythonNUTS / VIベイズ統計、ビジネス
PyroPython (PyTorch)SVI / HMC深層生成モデル
Turing.jlJuliaHMC / NUTSJulia エコシステム統合

Gen の強み

  • カスタム推論プログラム: 既存のアルゴリズムに縛られず自前で書ける
  • ハイブリッドモデル: 連続変数・離散変数・確率的制御フローを統合
  • Julia のパフォーマンス: C 並みの実行速度
  • 研究フレンドリー: 論文の最新アイデアを試しやすい

学習リソース

  • 公式サイト: gen.dev
  • GitHub: probcomp/Gen.jl
  • Tutorial: GenJulia の公式チュートリアル(線形回帰 → SLAM まで)
  • 論文: Gen: A General-Purpose Probabilistic Programming System with Programmable Inference(PLDI 2019)

FAQ

Q: Gen と Turing.jl はどちらが良い?
A: 統計モデル中心なら Turing が手軽。カスタム推論やシミュレータベースのモデルなら Gen。

Q: Python から呼び出せる?
A: PyJulia で呼び出し可能ですが、推奨は Julia ネイティブ。

Q: GPU で動く?
A: 部分的に。Gen のコア自体は CPU。連続変数の HMC で勾配計算は AD(Zygote.jl)と組み合わせます。

Q: 学習曲線は?
A: Stan / PyMC より急。確率プログラミングと Julia の両方を学ぶ必要があります。

編集
Post Share
子ページ
  1. インストール
同階層のページ
  1. Java
  2. PHP
  3. Python
  4. C#
  5. C++
  6. Ruby
  7. Go
  8. HTML
  9. CSS
  10. JavaScript
  11. TypeScript
  12. VBA
  13. Google Apps Script
  14. Julia
  15. Swift
  16. オブジェクト指向言語共通
  17. Gen