Stan を利用して回帰モデルのパラメーターを推定する方法

回帰モデル

単回帰

単回帰モデルは、応答変数 y を説明変数 x で説明するモデルである。y と x の関係は、y = β1x + β0 + e (e ~ norm(0, σ))、または y ~ norm(β1x + β0, σ) によって記述される。

// lm.stan
data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real beta_0;
  real beta_1;
  real<lower=0> sigma;
}
model {
  y ~ normal(beta_1 * x + beta_0, sqrt(sigma));
}

R で Stan コードを呼び出して実行するには、次のようにする。サンプルデータとして y = 2x + e, e ~ norm(0, 16) となるように乱数生成して、Stan コードを実行する。その結果として、x の係数が 1.93、切片が -0.50、分散が 15.80 として推測された。

library(rstan)
N <- 100
x <- runif(N) * 10
y <- x * 2 + rnorm(N, mean = 0, sd = 4)
d <- list(x = x, y = y, N = length(x))
fit <- stan(file = 'lm.stan', data = d)
## Inference for Stan model: lm.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##           mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
## beta_0   -0.50    0.02 0.84   -2.16   -1.07   -0.50    0.10    1.11  1372    1
## beta_1    1.93    0.00 0.14    1.67    1.84    1.93    2.03    2.21  1293    1
## sigma    15.80    0.05 2.34   11.88   14.14   15.53   17.22   21.03  2001    1
## lp__   -186.52    0.03 1.25 -189.71 -187.12 -186.23 -185.60 -185.07  1322    1
## 
## Samples were drawn using NUTS(diag_e) at Fri Jan  5 16:44:22 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).

重回帰

説明変数が複数存在するとき、x は行列として与えることができる。このとき、Stan コードに、説明変数の行列の行数と列数に関する定義を追加して対応する。

data {
    int<lower=0> N;
    int<lower=0> K;
    matrix[N,K] x;
    vector[N] y;
}
parameters {
    real beta_0;
    vector[K] beta;
    real<lower=0> sigma;
}
model {
    y ~ normal(x * beta + beta_0, sigma);
}

R で Stan コードを呼び出して実行するには、次のようにする。

library(rstan)
N <- 1000
K <- 5

x <- matrix(0, ncol = K, nrow = N)
for (k in 1:K) {
    x[, k] <- runif(N)
}

y <- x[, 1] + x[, 2] * 2 + x[, 3] * 4 - x[, 4] - 3 * x[, 5] + rnorm(N, mean = 0, sd = 1.5)

d <- list(x = x, y = y, N = nrow(x), K = ncol(x))
fit <- stan(file = 'mlm.stan', data = d)
fit
## Inference for Stan model: mlm.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##            mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
## beta_0     0.20    0.00 0.19   -0.16    0.08    0.20    0.33    0.57  2161    1
## beta[1]    1.27    0.00 0.16    0.94    1.16    1.27    1.38    1.59  3347    1
## beta[2]    1.73    0.00 0.16    1.42    1.62    1.74    1.85    2.04  3622    1
## beta[3]    3.80    0.00 0.16    3.47    3.69    3.80    3.91    4.11  4000    1
## beta[4]   -1.09    0.00 0.16   -1.42   -1.20   -1.10   -0.99   -0.77  4000    1
## beta[5]   -3.15    0.00 0.17   -3.48   -3.27   -3.15   -3.03   -2.82  3017    1
## sigma      1.51    0.00 0.03    1.44    1.48    1.50    1.53    1.58  4000    1
## lp__    -907.60    0.04 1.85 -911.86 -908.64 -907.32 -906.18 -904.95  2015    1
## 
## Samples were drawn using NUTS(diag_e) at Fri Jan  5 16:56:45 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).