元バイオ系

元バイオウェット系がデータサイエンスやらを勉強していくブログ。 基本自分用のまとめ。

Julia tips #6: Multiple-Try Metropolisを実装してみた

Multiple-Try Metropolis (MTM)って?

メトロポリスヘイスティングスの多数決バージョンです。
今いる場所から次の候補を複数生成し、その値周辺が今いる場所より尤もらしければ候補の中の1つへ移動します。

細かい話を書く元気がないので、理論的な背景は気が向いたらまたまとめます。
というわけで実装して行きましょう。

MTMアルゴリズム

任意の提案関数 T(x,y)を定義しT(x, y) > 0 \Longleftrightarrow T(y, x) > 0を満たすとする。
また、\lambda (x,y)\lambda (x,y) > 0で対称な任意の関数であるとする。
マルコフ連鎖の現在の状態を\boldsymbol{x_t}とおくと、MTMアルゴリズムは次のようになる。

  1. T(\boldsymbol{x_t}, \cdot)から\boldsymbol{y_1}, \dots, \boldsymbol{y_k}を独立に生成する

  2. 重みw(\boldsymbol{y_j}, \boldsymbol{x_t}) = \pi(\boldsymbol{y_j})T(\boldsymbol{y_j}, \boldsymbol{x_t})\lambda(\boldsymbol{y_j}, \boldsymbol{x_t})に従って{\boldsymbol{y_1}, \dots, \boldsymbol{y_k} }から\boldsymbol{y}をサンプリングして計算する。

  3. T (\boldsymbol{y}, \cdot)から\boldsymbol{x_1^*}, \dots, \boldsymbol{x_{k-1}^*}をサンプリングし、\boldsymbol{x_k^*}=\boldsymbol{x_t}とする。

  4. yを以下の確率で受容する。  {\displaystyle \alpha = \mathrm{min} \left\{1, \frac{\Sigma_{k} w(\boldsymbol{y_k}, \boldsymbol{x_t})}{\Sigma_{k} w(\boldsymbol{x_k}, \boldsymbol{y})} \right\} }

今回はベイズ計算統計学(統計解析スタンダード)p93 例3.7を実装してみた。提案関数と対称関数は以下のようになっている。

$$T(x, y) = x_t + \mathcal{N}\left(0, 5\sqrt{2}\right)$$ $$\lambda(\boldsymbol{y}, \boldsymbol{x})=\left(\frac{T(x, y)+T(y, x)}{2}\right)^{-1}$$

目的関数は、

$$\frac{1}{3} \mathcal{N}(-5,1) + \frac{2}{3} \mathcal{N}(5, 1)$$

である。

使うパッケージのインポート

using Distributions
using StatsBase
import Plots; plt = Plots; plt.pyplot()
using LaTeXStrings

# グラフ周りの設定
fntsm = plt.font("serif")
fntlg = plt.font("serif", 20)
plt.default(titlefont=fntlg, guidefont=fntsm, tickfont=fntsm, legendfont=fntsm)
plt.default(size=(800, 600)) 

目標分布の定義

$$\frac{1}{3} \mathcal{N}(-5,1) + \frac{2}{3} \mathcal{N}(5, 1)$$

const μ = [-5.0, 5.0]
const σ = [1.0, 1.0] 

function target(x::Array{Float64, 1})
    arr = zeros(length(x))
    Normal_val_1(x) = 1/sqrt(2π*σ[1]^2)*exp(-(x-μ[1])^2/(2σ[1]^2))
    Normal_val_2(x) = 1/sqrt(2π*σ[2]^2)*exp(-(x-μ[2])^2/(2σ[2]^2))
    for i in 1:length(x)
        arr[i] = 1/3 * Normal_val_1(x[i]) + 2/3 * Normal_val_2(x[i])
    end
    
    return arr
end

function target(x::Float64)
    Normal_val_1(x) = 0.45 * 1/sqrt(2π*σ[1]^2)*exp(-(x-μ[1])^2/(2σ[1]^2))
    Normal_val_2(x) = 0.45 * 1/sqrt(2π*σ[2]^2)*exp(-(x-μ[2])^2/(2σ[2]^2))
    val = 1/3 * Normal_val_1(x) + 2/3 * Normal_val_2(x)
    return val
end

x = Array{Float64}(linspace(-10, 10, 1000))
y = target(x)
l = plt.@layout [a b; c{0.2h}]
eq = L"$\frac{1}{3} \mathcal{N}(-5,1) + \frac{2}{3} \mathcal{N}(5, 1)$"
p = plt.plot(x, y, title="Target distribution", lab=eq)

f:id:hotoke-X:20180131003314p:plain

実装

まず、MTMに必要な関数やら値を保持する型を作っとく。

mutable struct Sampler
    x_present::Float64   # 現在の値
    target::Function   # 目標分布関数
    propose_next::Function   # 提案関数
    q::Function   # (x_t+1, x_t)の順で渡す   # 推移核(上で書いたT(x, y))
    λ::Function   # 対称関数
    burn_in::Int   # 焼きなましの回数
    count::Int   # 何回目の更新か保持
    random_state::MersenneTwister   # 乱数ジェネレータ
end

MTM本体

module MTM
using StatsBase   
    function present(sampler)   # 現在の値を返す関数
        return sampler.x_present
    end

    function next(sampler, n_candidate)   # MTM
        w_x2y = zeros(n_candidate)
        w_y2x = zeros(n_candidate)
        proposed = zeros(n_candidate)
        proposed_r = zeros(n_candidate)
        accept = 0
    
        # Σⱼ w(yⱼ,xₜ)
        for i in 1:n_candidate
            # q(yⱼ|xₜ)でyⱼをサンプリング
            proposed[i] = sampler.propose_next(sampler.x_present)
        
            # q(yⱼ|xₜ) と q(xₜ|yⱼ) を計算
            q_x2y = sampler.q(proposed[i], sampler.x_present)
            q_y2x = sampler.q(sampler.x_present, proposed[i])
        
            # w(yⱼ, x) = π(yⱼ)q(yⱼ|xₜ)λ(yⱼ, x)
            w_x2y[i] = sampler.target(proposed[i]) * q_x2y * sampler.λ(q_x2y, q_y2x)
        end
        
    
    
        # w(y,xₜ)したがってyをサンプリング
        y = sample(proposed, Weights(w_x2y))
    
        # Σⱼ w(xₜ, yⱼ)
        for i in 1:(n_candidate-1)
            # q(xⱼ|y)でxⱼをサンプリング
            proposed_r[i] = sampler.propose_next(y)
        
            # q(xⱼ|y) と q(y|xⱼ) を計算
            q_y2x = sampler.q(proposed_r[i], y)
            q_x2y = sampler.q(y, proposed_r[i])
        
            # w(xⱼ, y) = π(xⱼ)q(xⱼ|y)λ(xⱼ, y)
            w_y2x[i] = sampler.target(proposed_r[i]) * q_y2x * sampler.λ(q_y2x, q_x2y)
        end
    
        # set xₘ=x
        proposed_r[end] = sampler.x_present
    
        # q(xₘ|y) と q(y|xₘ) を計算
        q_y2x = sampler.q(proposed_r[end], y)
        q_x2y = sampler.q(y, proposed_r[end])
        w_y2x[end] = sampler.target(proposed_r[end]) * q_y2x * sampler.λ(q_y2x, q_x2y) 
        
    
        # αの計算
        odds = sum(w_x2y)/sum(w_y2x)
        
        # 受容するか棄却するか決定
        if rand(sampler.random_state) < odds
            sampler.x_present = y
            accept = 1
            return y, accept
        else
            return sampler.x_present, accept
        end
    end

    function burn_in(sampler, n_candidate)   # 焼きなまし期間
        arr = zeros(sampler.burn_in)
        accept = zeros(sampler.burn_in)
        for i in 1:sampler.burn_in
            arr[i], accept[i] = next(sampler, n_candidate)
            sampler.count += 1
        end
        return arr
    end

    function run(sampler, n_steps, n_candidate)   #好きなだけ回す
        arr = zeros(n_steps)
        accept = zeros(n_steps)
        for i in 1:n_steps
            arr[i], accept[i] = next(sampler, n_candidate)
            sampler.count += 1
        end
        return arr, accept
    end

end

上で作った MTMを動かすところ。

const σ_proposal = 5*sqrt(2)  


const seed = 0
rng = srand(seed)   # 乱数シードの固定とジェネレータの取得
gaussian(xₜ) = Normal(xₜ, σ_proposal^2)   # 推移核の分布関数
generate(xₜ) = rand(gaussian(xₜ))   #  推移核の分布関数からサンプリングする関数
uniform = Uniform(-10, 10)   # 一様分布関数
inits = rand(uniform)   # 初期値の生成

target = target   # 意味ないけど、見た目のため
q(y, xₜ) =  pdf(gaussian(xₜ), y)
λ(x, y) = ((x+y)/2)^(-1) # MTM-inv scheme(下の参考文献参照)

const burn_in = 1000000
const num_sample = 100000
cnt = 0

sampler_MTM = Sampler(inits, target, generate, q, λ, burn_in, cnt, rng)

MTMの実行

ベイズ計算統計学(統計解析スタンダード)p93 例3.7の例と同じように、候補数を1, 5, 10と変えてみた。
候補数1は単純なメトロポリスヘイスティングス法になります。

焼きなまし

const n_candidate=10   # MTMでつかう候補の数(ここでは10)
samples_burn_in = MTM.burn_in(sampler_MTM, n_candidate)

本番。ステップごとのサンプルと受容フラグが返ってくる。

samples, accept = MTM.run(sampler_MTM, num_sample, n_candidate)

結果の確認

n_slice = 3000
lags =  collect(1:200)
auto_correlation = autocor(samples[end-n_slice:end], lags)


layout = plt.@layout [a b; c]

p1 = plt.bar(auto_correlation, title="Auto correlation",
                lab="", xlab="Lag", ylab="Auto\ncorrelation (A.U.)", 
                xlim=(-10,200), ylim=(-0.1,1))
p2 = plt.plot(samples[end-n_slice:end], title="Trajectory", lab="")
p3 = plt.histogram(samples , title="Target distribution", lab="", norm=true, bins=128)
p3 = plt.plot!(x,y, color="magenta", lab="")

p=plt.plot(p1, p2, p3, layout=layout)
print(sum(accept[end-n_slice:end]) / length(accept[end-n_slice:end]))   # 採択率を算出

f:id:hotoke-X:20180131010445p:plain
候補数1(採択率:0.04798400533155615)

f:id:hotoke-X:20180131010726p:plain
候補数5(採択率:0.20626457847384205)

f:id:hotoke-X:20180131010851p:plain
候補数10(採択率:0.31756081306231254)

できました。
候補を増やすと、目標分布の2つの山を頻繁に行き来しているのがわかります。
採択確率も候補を増やしたほうが良くなっています。
ただ、候補の数分計算量が増えるので注意ですね。

まとめ

実装は簡単だったが、はてなブログに書くのが非常に疲れt
間違いがあったらご指摘ください。
ちなみに以前書いたメトロポリスヘイスティングスは少し間違ってます(笑)

参考文献

Pandolfi, S., Bartolucci, F. & Friel, N. A generalized multiple-try version of the Reversible Jump algorithm. Comput. Stat. Data Anal. 72, 298–314 (2014).