Stan を利用して回帰モデルのパラメーターを推定する方法
回帰モデル
単回帰
単回帰モデルは、応答変数 y を説明変数 x で説明するモデルである。y と x の関係は、y = β1x + β0 + e (e ~ norm(0, σ))、または y ~ norm(β1x + β0, σ) によって記述される。
// lm.stan
data {
int<lower=0> N;
vector[N] x;
vector[N] y;
}
parameters {
real beta_0;
real beta_1;
real<lower=0> sigma;
}
model {
y ~ normal(beta_1 * x + beta_0, sqrt(sigma));
}
R で Stan コードを呼び出して実行するには、次のようにする。サンプルデータとして y = 2x + e, e ~ norm(0, 16) となるように乱数生成して、Stan コードを実行する。その結果として、x の係数が 1.93、切片が -0.50、分散が 15.80 として推測された。
library(rstan)
N <- 100
x <- runif(N) * 10
y <- x * 2 + rnorm(N, mean = 0, sd = 4)
d <- list(x = x, y = y, N = length(x))
fit <- stan(file = 'lm.stan', data = d)
## Inference for Stan model: lm.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## beta_0 -0.50 0.02 0.84 -2.16 -1.07 -0.50 0.10 1.11 1372 1
## beta_1 1.93 0.00 0.14 1.67 1.84 1.93 2.03 2.21 1293 1
## sigma 15.80 0.05 2.34 11.88 14.14 15.53 17.22 21.03 2001 1
## lp__ -186.52 0.03 1.25 -189.71 -187.12 -186.23 -185.60 -185.07 1322 1
##
## Samples were drawn using NUTS(diag_e) at Fri Jan 5 16:44:22 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
重回帰
説明変数が複数存在するとき、x は行列として与えることができる。このとき、Stan コードに、説明変数の行列の行数と列数に関する定義を追加して対応する。
data {
int<lower=0> N;
int<lower=0> K;
matrix[N,K] x;
vector[N] y;
}
parameters {
real beta_0;
vector[K] beta;
real<lower=0> sigma;
}
model {
y ~ normal(x * beta + beta_0, sigma);
}
R で Stan コードを呼び出して実行するには、次のようにする。
library(rstan)
N <- 1000
K <- 5
x <- matrix(0, ncol = K, nrow = N)
for (k in 1:K) {
x[, k] <- runif(N)
}
y <- x[, 1] + x[, 2] * 2 + x[, 3] * 4 - x[, 4] - 3 * x[, 5] + rnorm(N, mean = 0, sd = 1.5)
d <- list(x = x, y = y, N = nrow(x), K = ncol(x))
fit <- stan(file = 'mlm.stan', data = d)
fit
## Inference for Stan model: mlm.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## beta_0 0.20 0.00 0.19 -0.16 0.08 0.20 0.33 0.57 2161 1
## beta[1] 1.27 0.00 0.16 0.94 1.16 1.27 1.38 1.59 3347 1
## beta[2] 1.73 0.00 0.16 1.42 1.62 1.74 1.85 2.04 3622 1
## beta[3] 3.80 0.00 0.16 3.47 3.69 3.80 3.91 4.11 4000 1
## beta[4] -1.09 0.00 0.16 -1.42 -1.20 -1.10 -0.99 -0.77 4000 1
## beta[5] -3.15 0.00 0.17 -3.48 -3.27 -3.15 -3.03 -2.82 3017 1
## sigma 1.51 0.00 0.03 1.44 1.48 1.50 1.53 1.58 4000 1
## lp__ -907.60 0.04 1.85 -911.86 -908.64 -907.32 -906.18 -904.95 2015 1
##
## Samples were drawn using NUTS(diag_e) at Fri Jan 5 16:56:45 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).