Stan basics

Outline

Topics

  • Anatomy of a simple Stan program.
  • Interpreting the output of Stan.

Rationale

We will write many Stan models in the coming weeks, so here we cover the key concepts needed to write Stan programs. The similarities with simPPLe will help in this process.

Example

Review: code in simPPLe and analytic answer

Let us revisit the Doomsday model, which we wrote in simPPLe as follows:

source("../../solutions/simple.R")

doomsday_model = function() {
  x = simulate(Unif(0, 5))
  observe(0.06, Unif(0, x))
  return(x)
}

posterior(doomsday_model, 100000)
[1] 1.105568

Recall that the analytic answer we computed in the clicker questions was \(\approx 1.117\).
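The value \(\approx 1.117\) can be re-derived in a few lines of R. This is a minimal sketch. The prior density of x is 1/5 on [0, 5], and the likelihood of y = 0.06 under Unif(0, x) is 1/x when x > 0.06 (and 0 otherwise). The posterior density is therefore proportional to 1/x on [0.06, 5]. The variable names below are illustrative, and `integrate` is used only as a numerical cross-check of the closed form.

```r
# Doomsday model: prior x ~ Unif(0, 5); the observation y = 0.06 ~ Unif(0, x)
# has likelihood 1/x when x > y, and 0 otherwise.
# Posterior density: proportional to 1/x on [y, upper].
y = 0.06
upper = 5

# Normalizing constant: integral of 1/x over [y, upper]
Z = log(upper / y)

# Posterior mean: integral of x * (1/x) / Z over [y, upper]
post_mean = (upper - y) / Z
post_mean

# Cross-check the closed form with numerical integration
integrate(function(x) x * (1 / x) / Z, lower = y, upper = upper)$value
```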

Code in Stan

Today we will rewrite this model in Stan instead of simPPLe. The main differences are:

  • both simulate and observe are denoted by ~ in Stan;
  • to distinguish observed from latent variables, Stan instead relies on variable declarations (in the data block or in the parameters block).

First, put the following code into a file called doomsday.stan:

doomsday.stan
data { // (1)
  real y;
}

parameters { // (2)
  real<lower=0, upper=5> x;
}

model {
  x ~ uniform(0, 5); // (3)
  y ~ uniform(0, x); // (4)
}

(1) Variables introduced in a data block are treated as observed.
(2) Variables introduced in a parameters block are treated as latent, i.e., unobserved.
(3) Since x is declared latent (i.e., inside a parameters block), Stan knows to treat this statement as the counterpart of simPPLe's x = simulate(Unif(0, 5)).
(4) Since y is declared observed (i.e., inside a data block), Stan knows to treat this statement as the counterpart of simPPLe's observe(0.06, Unif(0, x)).
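One way to read the model block is that each ~ statement adds a term to a log density, evaluated at the current value of the latent variable x with y held fixed. The R function below sketches that log density for the Doomsday model. It is not how Stan is implemented internally, and the name log_joint is hypothetical. Stan's ~ statements also drop additive constants, which does not change the posterior.

```r
# Hypothetical helper: log of prior(x) * likelihood(y | x) for the Doomsday model.
log_joint = function(x, y = 0.06) {
  dunif(x, min = 0, max = 5, log = TRUE) +   # x ~ uniform(0, 5)
    dunif(y, min = 0, max = x, log = TRUE)   # y ~ uniform(0, x)
}

log_joint(1)     # finite: y = 0.06 lies inside (0, 1)
log_joint(0.05)  # -Inf: y = 0.06 is impossible under Unif(0, 0.05)
```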

Question: How does Stan get the actual observed value \(y = 0.06\)?

Answer: the observed value is passed in when Stan is called from R:

suppressPackageStartupMessages(require(cmdstanr))
mod = cmdstan_model("doomsday.stan")
fit = mod$sample(
  seed = 1,
  chains = 1,
  refresh = 50000,
  output_dir = "stan_out",
  data = list(y = 0.06),   # (1)
  iter_warmup = 1000,
  iter_sampling = 100000   # (2)
)

(1) Pass in the values taken by the observed random variables.
(2) The number of samples to compute (a cousin of \(M\) in our Monte Carlo notation).
Running MCMC with 1 chain...

Chain 1 Iteration:      1 / 101000 [  0%]  (Warmup) 
Chain 1 Iteration:   1001 / 101000 [  0%]  (Sampling) 
Chain 1 Iteration:  51000 / 101000 [ 50%]  (Sampling) 
Chain 1 Iteration: 101000 / 101000 [100%]  (Sampling) 
Chain 1 finished in 0.9 seconds.

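We can now compare the approximation provided by Stan (the mean column for x) with the simPPLe estimate (1.105568) and the analytic answer (\(\approx 1.117\)):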

fit
 variable  mean median   sd  mad    q5   q95 rhat ess_bulk ess_tail
     lp__ -0.37  -0.12 0.64 0.14 -1.62 -0.02 1.00    19501    24505
     x     1.11   0.55 1.25 0.65  0.07  4.01 1.00    19501    24505
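For this one-dimensional model the posterior can also be sampled exactly, which gives a way to check what the mean, median, q5, and q95 columns estimate. This is a sketch that assumes the posterior density is proportional to 1/x on [0.06, 5]. Its CDF is F(x) = log(x / 0.06) / log(5 / 0.06), so x = 0.06 (5 / 0.06)^U with U ~ Unif(0, 1) is an exact draw (inverse transform sampling). The names are illustrative, and these are independent draws, not MCMC output.

```r
set.seed(1)
M = 100000
y = 0.06
upper = 5

# Inverse transform sampling from the posterior density proportional to 1/x on [y, upper]
u = runif(M)
xs_exact = y * (upper / y)^u

# Same summaries as the mean, median, q5 and q95 columns printed for fit
summ = c(
  mean   = mean(xs_exact),
  median = median(xs_exact),
  q5     = unname(quantile(xs_exact, 0.05)),
  q95    = unname(quantile(xs_exact, 0.95))
)
round(summ, 2)
```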

Tips and tricks

Interpreting the output

  • Just as with simple Monte Carlo and SNIS, the output of Stan/MCMC is a list of samples.
  • In contrast to SNIS, the samples are equally weighted:
    • in other words, for MCMC we have \(W^{(m)} = 1/M\),
    • i.e., their structure is the same as samples from simple Monte Carlo (one caveat: successive MCMC samples are typically correlated rather than independent).
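To make the contrast with SNIS concrete, here is a sketch of SNIS for the Doomsday model that uses the prior as the proposal. It is in the spirit of simPPLe but is an illustrative re-implementation, not simPPLe itself. Each draw carries a normalized weight \(W^{(m)}\). For equally weighted output with \(W^{(m)} = 1/M\), the weighted average reduces to the plain mean.

```r
set.seed(1)
M = 100000

# Proposal = prior: x ~ Unif(0, 5)
x = runif(M, 0, 5)

# Unnormalized weights = likelihood of the observation y = 0.06 under Unif(0, x)
w = dunif(0.06, min = 0, max = x)

# SNIS: normalize the weights, then take a weighted average
W = w / sum(w)
snis_mean = sum(W * x)
snis_mean

# Equally weighted samples (W^(m) = 1/M), as produced by MCMC:
# the weighted average is just the ordinary mean
W_equal = rep(1 / M, M)
all.equal(sum(W_equal * x), mean(x))
```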

To extract the samples of x from the fit object returned by cmdstanr, use:

xs = as.vector(fit$draws("x"))

As a sanity check, let's make sure all the samples are greater than the observed value \(y = 0.06\). This value is a lower bound for \(x\), since \(y \sim \text{Unif}(0, x)\) requires \(x > y\):

min(xs)
[1] 0.06001702

Question: denote the samples output from Stan by \(X^{(1)}, X^{(2)}, \dots, X^{(M)}\). What formula should you use to approximate \(\mathbb{E}[g(X) | Y = y]\) based on the Stan output?

  1. \(\sum_x p(x) g(x)\)
  2. \(\sum_x f(x) g(x)\)
  3. \(\frac{1}{M} \sum_{m=1}^M g(X^{(m)})\)
  4. \(\frac{1}{M} \sum_{m=1}^M f(X^{(m)}) g(X^{(m)})\)
  5. None of the above

Answer: option 3. As in simple Monte Carlo, use \(\frac{1}{M} \sum_{m=1}^M g(X^{(m)})\).
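In R, this estimator is a one-liner: apply g to the draws and average. This is a sketch. So that it runs without Stan, exact posterior draws stand in for xs; they are obtained by inverse transform sampling from the density proportional to 1/x on [0.06, 5]. The helper name mc_expectation is hypothetical. With the Stan output, you would use xs = as.vector(fit$draws("x")) instead.

```r
# Hypothetical helper: estimate E[g(X) | Y = y] from equally weighted draws
mc_expectation = function(draws, g) mean(g(draws))

# Stand-in for the Stan output: exact posterior draws via inverse transform
# sampling from the density proportional to 1/x on [0.06, 5]
set.seed(1)
xs = 0.06 * (5 / 0.06)^runif(100000)

post_mean = mc_expectation(xs, function(x) x)      # g(x) = x
prob_gt_1 = mc_expectation(xs, function(x) x > 1)  # g(x) = indicator of x > 1
c(post_mean = post_mean, prob_gt_1 = prob_gt_1)
```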

Plotting histograms

Since the samples are equally weighted, plotting is simple. For example, a histogram can be created directly:

hist(xs)
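Because the draws are equally weighted, a density-scaled histogram can also be compared directly with the exact posterior density, which is a useful check for this model. This is a sketch. The curve assumes the posterior density 1/(x log(5/0.06)) on [0.06, 5], and exact inverse-transform draws stand in for xs so the snippet runs without Stan. The name post_density is illustrative.

```r
# Stand-in for the Stan output (replace with xs = as.vector(fit$draws("x")))
set.seed(1)
xs = 0.06 * (5 / 0.06)^runif(100000)

# Exact posterior density of the Doomsday model: 1 / (x * log(5 / 0.06)) on [0.06, 5]
post_density = function(x) ifelse(x >= 0.06 & x <= 5, 1 / (x * log(5 / 0.06)), 0)

# freq = FALSE rescales the histogram so its total area is 1, matching a density
hist(xs, breaks = 100, freq = FALSE, main = "Doomsday posterior", xlab = "x")
curve(post_density(x), from = 0.06, to = 5, add = TRUE, lwd = 2)
```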

You can also use the bayesplot package, which is handy for creating more complex plots from MCMC runs:

suppressPackageStartupMessages(require(bayesplot))
mcmc_hist(fit$draws(), pars = c("x"))
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.