Rao-Blackwellization

Outline

Topics

  • Rao-Blackwellization: what and why
  • Implementing Rao-Blackwellization in Stan
  • Mathematical underpinnings

Rationale

Rao-Blackwellization is an important technique for two reasons:

  • It can speed up MCMC considerably.
  • In languages that do not support discrete latent variables (for example, Stan), it is the only way to implement certain models (e.g., mixture models, coming next).

Stan demo

  • We revisit the Chernobyl example.
  • This time, we implement it in Stan differently to demonstrate Rao-Blackwellization.
Code
set.seed(1)

# detection limit: values higher than this stay at the limit
limit = 1.1

n_measurements = 10

true_mean = 5.6

# the data, if we were able to observe it perfectly
y = rexp(n_measurements, 1.0/true_mean)

# number of measurements higher than the detection limit
n_above_limit = sum(y >= limit)
n_below_limit = sum(y < limit)

# subset of the measurements that are below the limit:
data_below_limit = y[y < limit]

# measurements: those higher than the limit stay at the limit
measurements = ifelse(y < limit, y, limit)
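
Notice which summaries the Stan program below will actually need: the individual censored values never appear, only how many there are. A small sketch of the data we will pass to Stan later (the name stan_data is just for illustration):

# only summaries reach Stan: the censored values enter solely through their count
stan_data = list(
  limit = limit,
  n_above_limit = n_above_limit,
  n_below_limit = n_below_limit,
  data_below_limit = data_below_limit
)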
suppressPackageStartupMessages(require(rstan))
data {
  int<lower=0> n_above_limit;
  int<lower=0> n_below_limit;
  real<lower=0> limit;
  vector<upper=limit>[n_below_limit] data_below_limit;
}

parameters {
  real<lower=0> rate;
}

model {
  // prior
  rate ~ exponential(1.0/100);
  
  // likelihood
  target += n_above_limit * exponential_lccdf(limit | rate);
  data_below_limit ~ exponential(rate); 
}

generated quantities {
  real mean = 1.0/rate;
}
Two things to notice in this Stan program:

  • In the parameters block, we are not including all these \(H_i\) above the detection limit! How is this possible?
  • In the model block, what is that exponential_lccdf?!

Let us make sure it works empirically first:
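
Before sampling, the Stan program above must be compiled into a model object; a minimal sketch, assuming it is saved in a file called chernobyl_rao_blackwellized.stan:

# assumption: the Stan program above is saved as "chernobyl_rao_blackwellized.stan"
chernobyl_rao_blackwellized = stan_model("chernobyl_rao_blackwellized.stan")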

fit = sampling(
  chernobyl_rao_blackwellized,
  seed = 1,
  chains = 1,
  data = list(
            limit = limit,
            n_above_limit = n_above_limit, 
            n_below_limit = n_below_limit,
            data_below_limit = data_below_limit
          ),       
  iter = 100000                   
)

SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 1.7e-05 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.17 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Iteration:     1 / 100000 [  0%]  (Warmup)
Chain 1: Iteration: 10000 / 100000 [ 10%]  (Warmup)
Chain 1: Iteration: 20000 / 100000 [ 20%]  (Warmup)
Chain 1: Iteration: 30000 / 100000 [ 30%]  (Warmup)
Chain 1: Iteration: 40000 / 100000 [ 40%]  (Warmup)
Chain 1: Iteration: 50000 / 100000 [ 50%]  (Warmup)
Chain 1: Iteration: 50001 / 100000 [ 50%]  (Sampling)
Chain 1: Iteration: 60000 / 100000 [ 60%]  (Sampling)
Chain 1: Iteration: 70000 / 100000 [ 70%]  (Sampling)
Chain 1: Iteration: 80000 / 100000 [ 80%]  (Sampling)
Chain 1: Iteration: 90000 / 100000 [ 90%]  (Sampling)
Chain 1: Iteration: 100000 / 100000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 0.154 seconds (Warm-up)
Chain 1:                0.152 seconds (Sampling)
Chain 1:                0.306 seconds (Total)
Chain 1: 
fit
Inference for Stan model: anon_model.
1 chains, each with iter=1e+05; warmup=50000; thin=1; 
post-warmup draws per chain=50000, total post-warmup draws=50000.

      mean se_mean   sd   2.5%   25%   50%   75% 97.5% n_eff Rhat
rate  0.39    0.00 0.20   0.11  0.25  0.36  0.50  0.86 17497    1
mean  3.39    0.02 2.35   1.16  1.99  2.78  4.02  9.25 13206    1
lp__ -8.24    0.01 0.73 -10.30 -8.40 -7.96 -7.77 -7.72 18171    1

Samples were drawn using NUTS(diag_e) at Tue Mar 19 07:17:10 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Compare this to the result from the page on censoring, where we got:

  ...
1.258 seconds (Total)
  ...
mean  3.35    0.02 2.43   1.16   1.97   2.74   3.96   9.20
  ...

Conclusion: Essentially the same result (i.e., within MCSE), but about 4x faster! How is this possible?

Mathematical underpinnings

  • Consider a simplified example where there is only one observation:

\[\begin{align*} X &\sim {\mathrm{Exp}}(1/100) \\ H &\sim {\mathrm{Exp}}(X) \\ C &= \mathbb{1}[H \ge L], \\ Y &= C L + (1 - C) H. \end{align*} \tag{1}\]

  • Suppose our one observation is censored (\(C = 1\)).
  • The first Stan model we implemented targets a distribution over both \(X\) and \(H\), \(\gamma(x, h) = f(x, h, y)\), where \(y\) is fixed at its observed value.
  • Key idea behind Rao-Blackwellization:
    • Reduce the problem to a target over \(x\) only, \(\gamma(x)\).
    • This is because we do not care much about \(h\): it is a nuisance variable.
    • But we would like to do so in such a way that the result of inference on \(x\) is the same under \(\gamma(x, h)\) and \(\gamma(x)\) (just faster with Rao-Blackwellization).

Question: how to define \(\gamma(x)\) from \(\gamma(x, h)\) so that the result on \(x\) stays the same?

  1. \(\max_x \gamma(x, h)\)
  2. \(\max_h \gamma(x, h)\)
  3. \(\int \gamma(x, h) \mathrm{d}x\)
  4. \(\int \gamma(x, h) \mathrm{d}h\)
  5. None of the above.

The correct answer is 4:

\[\begin{align*} \int \gamma(x, h) \mathrm{d}h &= \int f(x, h, y) \mathrm{d}h \\ &= f(x, y) = \gamma(x). \end{align*}\]
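
To spell out why inference on \(x\) is unchanged: normalizing \(\gamma(x)\) yields exactly the \(x\)-marginal of the normalized joint target,

\[\begin{align*} \pi(x) = \frac{\gamma(x)}{\int \gamma(x') \,\mathrm{d}x'} = \frac{\int \gamma(x, h) \,\mathrm{d}h}{\int \int \gamma(x', h) \,\mathrm{d}h \,\mathrm{d}x'} = \int \pi(x, h) \,\mathrm{d}h, \end{align*}\]

so both targets lead to the same posterior distribution over \(x\).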

Question: compute \(\gamma(x)\) in our simplified example, Equation 1.

  1. \(f(x) F(L; x)\)
  2. \(f(x) (1 - F(L; x))\)
  3. \(f(x) F(x; L)\)
  4. \(f(x) (1 - F(x; L))\)
  5. None of the above.

Here \(F(h; x)\) denotes the cumulative distribution function of the likelihood given parameter \(x\).

The correct answer is 2. Since \(y = L\), the detection limit,

\[\begin{align*} \gamma(x) &= \int \gamma(x, h) \mathrm{d}h \\ &= \int f(x, h, y) \mathrm{d}h \\ &= \int f(x) f(h | x) f(y | h) \mathrm{d}h \\ &= f(x) \int f(h | x) \mathbb{1}[L \le h] \mathrm{d}h \\ &= f(x) (1 - F(L; x)). \end{align*}\]
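
As a quick sanity check on the last step, we can compare the integral over \(h\) with the closed form numerically in R (the value of x below is an arbitrary illustrative rate):

# numerical check of the marginalization at one arbitrary rate value x
x = 0.7

# left-hand side: integrate f(h | x) 1[h >= L] over h
lhs = integrate(function(h) dexp(h, rate = x), lower = limit, upper = Inf)$value

# right-hand side: the closed form 1 - F(L; x)
rhs = 1 - pexp(limit, rate = x)

c(lhs, rhs)  # the two should agree up to numerical error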

Note that in the Stan code above, \(\log(1 - F(L; x))\) is implemented as exponential_lccdf(limit | rate) (“lccdf” stands for “log complementary cumulative distribution function”); multiplying by n_above_limit accounts for the fact that each censored observation contributes an identical term to the log density.
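
The same quantity can be reproduced in R with pexp; a minimal sketch (rate below is again an arbitrary value):

# exponential_lccdf(limit | rate) corresponds to log(1 - F(limit; rate))
rate = 0.5  # arbitrary illustrative value
log(1 - pexp(limit, rate))                           # direct computation
pexp(limit, rate, lower.tail = FALSE, log.p = TRUE)  # numerically stabler equivalent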