Rao-Blackwellization

Outline

Topics

  • Rao-Blackwellization: what and why
  • Implementing Rao-Blackwellization in Stan
  • Mathematical underpinnings

Rationale

Rao-Blackwellization is an important technique for two reasons:

  • It can speed up MCMC considerably.
  • In languages that do not support discrete latent variables (for example, Stan), it is the only way to implement certain models (e.g., mixture models, covered next).

Stan demo

  • We revisit the Chernobyl example.
  • This time we implement it differently in Stan to demonstrate Rao-Blackwellization.
Code
set.seed(1)

# detection limit: values higher than the limit stay at the limit
limit = 1.1 

n_measurements = 10

true_mean = 5.6

# data, if we were able to observe perfectly
y = rexp(n_measurements, 1.0/true_mean)

# number of measurements higher than the detection limit
n_above_limit = sum(y >= limit)
n_below_limit = sum(y < limit)

# subset of the measurements that are below the limit:
data_below_limit = y[y < limit]

# measurements: those higher than the limit stay at the limit
measurements = ifelse(y < limit, y, limit)
suppressPackageStartupMessages(library(cmdstanr))
chernobyl_rao_blackwellized.stan
data {
  int<lower=0> n_above_limit;
  int<lower=0> n_below_limit;
  real<lower=0> limit;
  vector<upper=limit>[n_below_limit] data_below_limit;
}

parameters {
  real<lower=0> rate; // (1)
}

model {
  // prior
  rate ~ exponential(1.0/100);
  
  // likelihood
  target += n_above_limit * exponential_lccdf(limit | rate); // (2)
  data_below_limit ~ exponential(rate); 
}

generated quantities {
  real mean = 1.0/rate;
}
(1) Notice that this time we are not including all these \(H_i\) above the detection limit! How is this possible?

(2) What is that exponential_lccdf?!

Let us make sure it works empirically first:

mod = cmdstan_model("chernobyl_rao_blackwellized.stan")
fit = mod$sample(
  seed = 1,
  chains = 1,
  data = list(
            limit = limit,
            n_above_limit = n_above_limit, 
            n_below_limit = n_below_limit,
            data_below_limit = data_below_limit
          ),  
  output_dir = "stan_out",
  refresh = 20000,
  iter_sampling = 100000                   
)
Running MCMC with 1 chain...

Chain 1 Iteration:      1 / 101000 [  0%]  (Warmup) 
Chain 1 Iteration:   1001 / 101000 [  0%]  (Sampling) 
Chain 1 Iteration:  21000 / 101000 [ 20%]  (Sampling) 
Chain 1 Iteration:  41000 / 101000 [ 40%]  (Sampling) 
Chain 1 Iteration:  61000 / 101000 [ 60%]  (Sampling) 
Chain 1 Iteration:  81000 / 101000 [ 80%]  (Sampling) 
Chain 1 Iteration: 101000 / 101000 [100%]  (Sampling) 
Chain 1 finished in 0.5 seconds.
fit$summary(NULL, c("mean", "mcse_mean"))
# A tibble: 3 × 3
  variable   mean mcse_mean
  <chr>     <dbl>     <dbl>
1 lp__     -8.24    0.00395
2 rate      0.395   0.00100
3 mean      3.38    0.0150 

Compare this to the result from the page on censoring, where we got:

  ...
Chain 1 finished in 1.6 seconds.
  ...
   variable               mean mcse_mean
  ...
10 mean                  3.41   0.0178  

Conclusion: Essentially the same result (i.e., within MCSE), but the new version is faster! How is this possible?

Mathematical underpinnings

  • Consider a simplified example where there is only one observation:

\[\begin{align*} X &\sim {\mathrm{Exp}}(1/100) \\ H &\sim {\mathrm{Exp}}(X) \\ C &= \mathbb{1}[H \ge L], \\ Y &= C L + (1 - C) H. \end{align*} \tag{1}\]

  • Suppose our one observation is censored (\(C = 1\)).
  • The first Stan model we implemented targets a distribution over both \(X\) and \(H\), \(\gamma(x, h) = f(x, h, y)\).
  • Key idea behind Rao-Blackwellization:
    • Reduce the problem to a target over \(x\) only, \(\gamma(x)\).
    • This is because we do not care much about \(h\): it is a nuisance variable.
    • But we would like to do so in such a way that the result of inference on \(x\) is the same whether we use \(\gamma(x, h)\) or \(\gamma(x)\) (just faster with Rao-Blackwellization).
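As a sanity check of the setup, one draw from the generative model in Equation 1 can be simulated directly in R (a sketch; `L` matches the detection limit from the demo above, the other variable names are ours):

```r
set.seed(1)
L <- 1.1                     # detection limit, as in the demo above
x <- rexp(1, rate = 1/100)   # X ~ Exp(1/100)
h <- rexp(1, rate = x)       # H ~ Exp(X)
c <- as.numeric(h >= L)      # C = 1[H >= L]
y <- c * L + (1 - c) * h     # Y = C L + (1 - C) H
```

By construction, \(y \le L\): the observation is either below the limit or pinned at it.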

Question: how to define \(\gamma(x)\) from \(\gamma(x, h)\) so that the result on \(x\) stays the same?

  1. \(\max_x \gamma(x, h)\)
  2. \(\max_h \gamma(x, h)\)
  3. \(\int \gamma(x, h) \mathrm{d}x\)
  4. \(\int \gamma(x, h) \mathrm{d}h\)
  5. None of the above.

The correct answer is 4:

\[\begin{align*} \int \gamma(x, h) \mathrm{d}h &= \int f(x, h, y) \mathrm{d}h \\ &= f(x, y) = \gamma(x). \end{align*}\]

Question: compute \(\gamma(x)\) in our simplified example, Equation 1.

  1. \(f(x) F(L; x)\)
  2. \(f(x) (1 - F(L; x))\)
  3. \(f(x) F(x; L)\)
  4. \(f(x) (1 - F(x; L))\)
  5. None of the above.

Here, \(F(h; x)\) denotes the cumulative distribution function of the likelihood with parameter \(x\).

The correct answer is 2. Since \(y = L\), the detection limit,

\[\begin{align*} \gamma(x) &= \int \gamma(x, h) \mathrm{d}h \\ &= \int f(x, h, y) \mathrm{d}h \\ &= \int f(x) f(h | x) f(y | h) \mathrm{d}h \\ &= f(x) \int f(h | x) \mathbb{1}[L \le h] \mathrm{d}h \\ &= f(x) (1 - F(L; x)). \end{align*}\]
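The last step, \(\int f(h | x) \mathbb{1}[L \le h] \mathrm{d}h = 1 - F(L; x)\), can be checked by Monte Carlo in R (a sketch with arbitrary test values for \(x\) and \(L\)):

```r
set.seed(1)
x <- 0.9; L <- 1.1
h <- rexp(1e6, rate = x)          # draws from f(h | x)
mc_estimate <- mean(h >= L)       # Monte Carlo estimate of P(H >= L | x)
exact <- 1 - pexp(L, rate = x)    # 1 - F(L; x) = exp(-x * L)
abs(mc_estimate - exact)          # should be close to zero
```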

Note that in the Stan code above, \(1 - F(L; x)\) is implemented with exponential_lccdf(limit | rate) (“lccdf” stands for “log complementary cumulative distribution function”). Since each of the n_above_limit censored measurements contributes this same factor, its log is multiplied by n_above_limit.
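In R, the same quantity is available from pexp() with lower.tail = FALSE and log.p = TRUE, which makes the connection easy to verify (arbitrary test values below):

```r
rate <- 0.4; limit <- 1.1
lccdf_r <- pexp(limit, rate = rate, lower.tail = FALSE, log.p = TRUE)
# for the exponential, 1 - F(L; x) = exp(-x * L), so the lccdf is just -x * L
all.equal(lccdf_r, -rate * limit)
```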