Stan を利用して回帰モデルのパラメーターを推定する方法

単回帰モデル

このページでは、Stan を使用して簡単な回帰モデルのパラメーターを推定する例を示す。ここで、サンプルデータとして、trees と呼ばれているデータセットを使用する。このデータセットには 31 本の桜の木の周長(Grith)、高さ(Height)、および容積(Volume)のデータが記録されている。樹木が高くなるには、それを支えるために幹も太くなる必要がある。そのため、幹が太さを用いて、木の高さを推定できるかもしれない。そこで、この仮定をもとに、Grith を説明変数とし、Height を応答変数として、回帰モデルを構築する。

data(trees)
head(trees)
##   Girth Height Volume
## 1   8.3     70   10.3
## 2   8.6     65   10.3
## 3   8.8     63   10.2
## 4  10.5     72   16.4
## 5  10.7     81   18.8
## 6  10.8     83   19.7

library(ggplot2)
g <- ggplot(trees, aes(x = Girth, y = Height)) + geom_point()
print(g)
R の trees データセット

Stan コードによるモデル記述

単回帰モデルは、応答変数 y を説明変数 x で説明するモデルである。y と x の関係は次の関係がたり立つ。

\[ y = \mathcal{N}\left(\mu, \sigma \right)\] \[ \mu = \beta_{1}x + \beta_{0} \]

このモデルの入力データは x (Girth) と y (Height) であり、パラメーターは β0、β1 および σ である。x および y は、後から与える入力データであるので data ブロックで定義する。β0、β1 および σ は、これから MCMC サンプリングを通して推定したいパラメーターであるので、parameters ブロックで定義する。また、回帰モデルにおける信頼区間および予測区間を計算するために、推定されたパラメーター(β0、β1 および σ)を用いて、Girth が 5-30 の範囲にあるときの木の高さをサンプリングするコードを generated quantities ブロックに記述する。

// lm.stan
data {
  int N;
  vector[N] x;
  vector[N] y;

  int new_N;
  vector[new_N] new_x;
}
parameters {
  real beta_0;
  real beta_1;
  real<lower=0> sigma;
}
transformed parameters {
  real mu[N];
  for (i in 1:N) {
    mu[i] = beta_1 * x[i] + beta_0;
  }
}
model {
  for (i in 1:N) {
    y[i] ~ normal(mu[i], sigma);
  }
}
generated quantities {
  real yhat[new_N];
  real muhat[new_N];
  for (i in 1:new_N) {
    muhat[i] = beta_1 * new_x[i] + beta_0;
    yhat[i] = normal_rng(muhat[i], sigma);
  }
}

パラメーター推定

R からデータを Stan モデルに代入するとき、R でデータをリストとして用意する必要がある。Stan コードで記述された data ブロックには、xyNnew_xnew_N の変数が定義されている。R で準備するリストにもこれらの名前をつける必要がある。データを準備できたら、これを stan 関数に代入してパラメーター推定を行う。stan 関数が実行されると、MCMC サンプリング結果が返される。この結果に、parameters および generated quantities ブロックで定義されているパラメーターはの事後分布が含まれる。

library(rstan)
d <- list(x = trees$Girth, y = trees$Height, N = nrow(trees),
          new_x = seq(5, 30, 0.1), new_N = length(seq(5, 30, 0.1)))

fit <- stan(file = 'lm.stan', data = d)
fit
## Inference for Stan model: lm.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##             mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
## beta_0     62.07    0.12 4.54  52.83  59.10  62.14  65.01  71.07  1516    1
## beta_1      1.05    0.01 0.33   0.39   0.84   1.04   1.27   1.71  1511    1
## sigma       5.79    0.02 0.81   4.46   5.20   5.70   6.28   7.56  1643    1
## yhat[1]    67.19    0.12 6.59  53.71  62.94  67.26  71.54  80.32  3177    1
## yhat[2]    67.54    0.12 6.56  54.67  63.20  67.52  71.82  80.49  3224    1
## yhat[3]    67.65    0.12 6.54  54.79  63.25  67.64  71.90  80.39  3186    1
## ...
## yhat[248]  93.30    0.17 7.96  77.51  88.01  93.41  98.63 108.75  2304    1
## yhat[249]  93.32    0.16 8.10  76.58  87.98  93.61  98.73 108.88  2465    1
## yhat[250]  93.70    0.16 7.98  78.19  88.33  93.88  99.15 108.94  2537    1
## yhat[251]  93.74    0.16 8.16  77.65  88.26  93.76  99.21 109.42  2561    1
## ...
## muhat[1]    67.11    0.09 3.16  60.95  65.08  67.07  69.14  73.27  1125   1
## muhat[2]    67.22    0.09 3.13  61.12  65.21  67.18  69.22  73.33  1127   1
## muhat[3]    67.33    0.09 3.09  61.30  65.34  67.29  69.31  73.37  1129   1
## muhat[4]    67.44    0.09 3.06  61.46  65.47  67.41  69.40  73.40  1131   1
## ...
## muhat[249]  93.81    0.18 5.98  82.25  89.97  93.88  97.70 105.39  1078   1
## muhat[250]  93.92    0.18 6.02  82.28  90.05  93.98  97.84 105.56  1078   1
## muhat[251]  94.03    0.18 6.05  82.31  90.12  94.08  97.96 105.74  1078   1
## lp__       -67.48    0.05 1.38 -70.85 -68.11 -67.10 -66.50 -65.96   851   1

fit 変数にはパラメーター推定時の設定条件およびパラメーターの推定結果が記録されている。例えば、この場合、4 つの chain で、それぞれ 2000 回のサンプリングが行われたことがわかる。また、2000 回のサンプリングのうち、最初の 1000 回はウォーミングアップのステップで、パラメーター推定時に使われていない。また、thin は間引きの間隔を表し、thin=1 の場合は、間引きが行われていなく、サンプリング結果をすべてを用いてパラメーター推定が行われる。例えば、thin=2 の場合、サンプリング結果のうちパラメーター推定に使用するのは 1 個おきとなる。

推定されたパラメーターの要約統計量は結果の中間部分に表示される。上の例では、beta_0beta_1sigma が推定されたパラメーターの要約統計量である。また、muhat および yhat は、generated quantities ブロックで再サンプリングされた値の要約統計量である。これらのサンプリング結果は extract 関数を使用して、行列または配列の形で取得できる。

ベイズ信用区間

信用区間とは、あるパラメーター(例えば母平均)の事後分布において、真の値が 1-α % の確率で含まれる区間である。Stan でモデルを構築したときに generated quantities ブロックで母平均 muhat の事後分布をサンプリングしたので、このパラメーターを取得すれば、95% 信用区間を求めることができるようになる。

library(dplyr)
library(ggplot2)
ms <- rstan::extract(fit)

df <- data.frame(d$new_x, t(apply(ms$muhat, 2, quantile, probs = c(0.025, 0.500, 0.975))))
colnames(df) <- c('x', 'lower', 'median', 'upper')

g <- ggplot(df, aes(x = x))
g <- g + geom_ribbon(aes(ymin = lower, ymax = upper), fill = '#000000', alpha = 0.4)
g <- g + geom_line(aes(y = median))
g <- g + geom_point(data = trees, aes(x = Girth, y = Height))
g <- g + xlab('Girth') + ylab('Height')
print(g)
ベイズ信用区間

ベイズ予測区間

generated quantities ブロックでは、5-30 までの Girth を与えて、Height をサンプリングするように記述した。このサンプリング結果が yhat に保存される。ここで、extract 関数を用いて yhat のデータを取り出して、95% 予測区間を求める。

df <- data.frame(d$new_x, t(apply(ms$yhat, 2, quantile, probs = c(0.025, 0.500, 0.975))))
colnames(df) <- c('x', 'lower', 'median', 'upper')

g <- ggplot(df, aes(x = x))
g <- g + geom_ribbon(aes(ymin = lower, ymax = upper), fill = '#000000', alpha = 0.4)
g <- g + geom_line(aes(y = median))
g <- g + geom_point(data = trees, aes(x = Girth, y = Height))
g <- g + xlab('Girth') + ylab('Height')
print(g)
ベイズ予測区間