Stochastic gradient estimation

Outline

Topics

  • Stochastic gradient estimation problem
  • Push-out estimator (also known as “reparameterization trick”)
  • Subsampling

Rationale

To optimize our variational objective \(L\) using SGD, we need to construct an unbiased estimator of the gradient of \(L\).

We review one particularly effective method for doing so, coming from the operations research literature: the push-out estimator (Rubinstein, 1992), called the “reparameterization trick” in the machine learning literature.

Variational inference objective

Recall, our VI objective function is \[\begin{align*} L(\phi) &= \int q_\phi(x) \left[ \log q_\phi(x) - \log \gamma(x) \right] \mathrm{d}x \\ &= \mathbb{E}_\phi\left[ \log q_\phi(X) - \log \gamma(X) \right], \end{align*}\] where we write a subscript \(\phi\) on \(\mathbb{E}_\phi[\cdot]\) to emphasize that the expectation is with respect to a distribution \(q_\phi\) that depends on \(\phi\).

Generic gradient estimation setup

To make notation cleaner, we abstract out the problem to:

Definition: the stochastic gradient estimation problem consists in finding an unbiased estimator for the gradient of \(\mathbb{E}_\phi[h(\phi, X)],\) where \(X \sim q_\phi\).

Example: for VI, \(h(\phi, x) = \log q_\phi(x) - \log \gamma(x)\).
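
To make this concrete, here is a minimal JAX sketch of \(h\) for a one-dimensional normal \(q_\phi\); the target `log_gamma` is a hypothetical stand-in (a standard normal, up to an additive constant) chosen only for illustration.

```python
import jax.numpy as jnp

def log_gamma(x):
    # hypothetical unnormalized target: a standard normal, up to an additive constant
    return -0.5 * x ** 2

def log_q(phi, x):
    # log density of N(mu, sigma^2)
    mu, sigma = phi
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)

def h(phi, x):
    # h(phi, x) = log q_phi(x) - log gamma(x)
    return log_q(phi, x) - log_gamma(x)

print(h((0.0, 1.0), jnp.array(0.5)))  # evaluate h at phi = (mu, sigma) = (0, 1)
```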

Difficulty

Find the most serious error in the buggy argument below:

\[\begin{align*} \nabla \mathbb{E}_\phi[h(\phi, X)] &= \mathbb{E}_\phi[\nabla h(\phi, X)] \;\;\text{(interchange of $\nabla$ and $\mathbb{E}$)} \\ &\approx \underbrace{ \frac{1}{M} \sum_{m=1}^M \nabla h(\phi, X^{(m)})}_{\text{broken stochastic gradient estimator}} \;\;\text{(simple Monte Carlo)}. \end{align*}\]
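
For a concrete sanity check (a toy example added here for illustration): take \(h(\phi, x) = x\) and \(q_\phi = \mathcal{N}(\phi, 1)\). Then \(\mathbb{E}_\phi[h(\phi, X)] = \phi\), so the true gradient is \(1\); but \(\nabla h(\phi, x) = 0\) for every \(x\), so the “estimator” above always returns \(0\). The first step is the culprit: the density \(q_\phi\) itself depends on \(\phi\), and the naive interchange ignores that dependence.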

Solution: reparameterization

Idea: move the parts of \(q\) that depend on \(\phi\) into \(h\) via reparameterization.

Example:

  • suppose \(\{q_\phi\}\) is a normal family, so \(\phi = (\mu, \sigma^2)\).
  • Note:
    • if \(S \sim \mathcal{N}(0, 1)\) is standard normal,
    • then \(\sigma S + \mu \sim \mathcal{N}(\mu, \sigma^2)\)
    • hence: \[\mathbb{E}_\phi[h(\phi, X)] = \mathbb{E}[h(\phi, \sigma S + \mu)].\]
  • Notice that on the right-hand side, the distribution with respect to which we take the expectation no longer depends on \(\phi\)!
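
As a quick numerical sanity check of this identity (a sketch with arbitrary illustrative values, using the closed form \(\mathbb{E}[\cos X] = \cos(\mu)\, e^{-\sigma^2/2}\) for \(X \sim \mathcal{N}(\mu, \sigma^2)\) as ground truth):

```python
import jax
import jax.numpy as jnp

mu, sigma = 1.5, 0.7
h = lambda x: jnp.cos(x)  # a test function of x alone, for simplicity

s = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))  # S ~ N(0, 1)
print(jnp.mean(h(mu + sigma * s)))             # E[h(sigma S + mu)], Monte Carlo
print(jnp.cos(mu) * jnp.exp(-sigma ** 2 / 2))  # E_phi[h(X)], closed form
```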

General method (“push-out estimator” or “reparameterization trick”)

  • if for all \(\phi\), \(X_\phi \sim q_\phi\)
  • …you can write \(X_\phi = f(S, \phi)\)
    • for some random variable \(S\) whose distribution does not depend on \(\phi\)
    • and function \(f(s, \phi)\) (e.g., \(f(s, \phi) = \sigma s + \mu\) in the normal case),1
  • then:

\[\begin{align*} \nabla \mathbb{E}_\phi[h(\phi, X)] &= \nabla \mathbb{E}[h(\phi, f(S, \phi))] \\ &= \mathbb{E}[\nabla\{h(\phi, f(S, \phi))\}] \;\;\text{(assuming Leibniz integral rule applies)} \\ &\approx \frac{1}{M} \sum_{m=1}^M \nabla\{h(\phi, f(S^{(m)}, \phi))\}. \end{align*}\]

Finally, one typically uses reverse-mode autodiff to compute \(\nabla\{h(\phi, f(S^{(m)}, \phi))\}\).
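
Putting the pieces together, here is a minimal JAX sketch of the resulting estimator for the normal family; `log_gamma` is the same hypothetical target as above, and we parameterize \(\phi = (\mu, \log \sigma)\) (an unconstrained variant of \((\mu, \sigma^2)\)) so that plain SGD can be applied:

```python
import jax
import jax.numpy as jnp

def log_gamma(x):
    # hypothetical unnormalized target (standard normal up to a constant)
    return -0.5 * x ** 2

def L_hat(phi, s):
    # Monte Carlo estimate of L(phi) built from reparameterized samples:
    # x = f(s, phi) = sigma * s + mu, so the distribution of s is fixed
    # and the gradient can be pushed inside the expectation.
    mu, log_sigma = phi
    sigma = jnp.exp(log_sigma)
    x = mu + sigma * s
    log_q = (-0.5 * ((x - mu) / sigma) ** 2
             - log_sigma - 0.5 * jnp.log(2 * jnp.pi))
    return jnp.mean(log_q - log_gamma(x))

# reverse-mode autodiff through h(phi, f(S, phi)), i.e., through both the
# direct phi-dependence of h and the phi-dependence of the samples
grad_estimator = jax.grad(L_hat)

phi = (jnp.array(0.5), jnp.array(-1.0))              # (mu, log sigma)
s = jax.random.normal(jax.random.PRNGKey(0), (64,))  # M = 64 draws of S
print(grad_estimator(phi, s))  # unbiased estimate of the gradient of L
```

Because `L_hat` draws \(X\) as \(f(S, \phi) = \sigma S + \mu\) inside the differentiated function, `jax.grad` automatically accounts for the dependence of the samples on \(\phi\), which is exactly what the broken estimator above missed.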

Subsampling

  • Often (but not always), the function \(h\) can be written as: \[h(\phi, x) = \sum_{i=1}^N h_i(\phi, x),\]
  • example: when the data are i.i.d. (recall we are in log space, \(h(\phi, x) = \log q_\phi(x) - \log \gamma(x)\), so \(\log \gamma\) contains a sum of per-data-point log-likelihood terms)
    • \(N\) is then the number of data points
  • When \(N\) is large, “subsampling” trades off:
    • a computationally cheaper gradient…
    • …at the cost of more noise (variance), and hence more SGD iterations \(t\) needed.
  • Idea: subsampling consists in
    • sampling one term (data point) \(I \sim {\mathrm{Unif}}\{1, 2, \dots, N\}\)
    • computing an unbiased estimate \(\hat G_I\) of the gradient of the selected term
    • debiasing by returning \(N \hat G_I\)

Property: the subsampling estimator \(N \hat G_I\) is unbiased for the full gradient, provided that for each \(i\), \(\hat G_i\) is an unbiased estimate of the gradient of the \(i\)-th term.

Proof:

\[\begin{align*} \mathbb{E}[N \hat G_I] &= N \mathbb{E}[ \mathbb{E}[\hat G_I|I]] \;\;\text{(law of total expectation)} \\ &= N \sum_{i=1}^N \frac{1}{N} \mathbb{E}[\hat G_i] \\ &= \sum_{i=1}^N \mathbb{E}[\hat G_i], \end{align*}\] which is the full gradient, since each \(\hat G_i\) is by assumption unbiased for the gradient of the \(i\)-th term.

Mini-batching: this idea extends to picking a small subset of points at each iteration (typically, the largest batch that fits in GPU memory).
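
A sketch of both the single-term and mini-batch variants, assuming a hypothetical callable `per_term_grad(i)` that returns an unbiased estimate \(\hat G_i\) (e.g., the reparameterized estimator above applied to the \(i\)-th term):

```python
import jax
import jax.numpy as jnp

def subsampled_grad(key, per_term_grad, N):
    # sample one term I ~ Unif{0, ..., N-1}, then debias by multiplying by N
    i = jax.random.randint(key, (), 0, N)
    return N * per_term_grad(i)

def minibatch_grad(key, per_term_grad, N, B):
    # sample B distinct terms, average their estimates, then debias by N;
    # assumes per_term_grad returns an array (e.g., a flat gradient vector)
    idx = jax.random.choice(key, N, (B,), replace=False)
    return N * jnp.mean(jax.vmap(per_term_grad)(idx), axis=0)
```

Both are unbiased by the property above; the mini-batch version averages \(B\) single-term estimates, lowering the variance at \(B\) times the per-iteration cost.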

References

See Mohamed et al., 2020, Monte Carlo Gradient Estimation in Machine Learning, JMLR.

Footnotes

  1. Typically, in order for the interchange of gradient and expectation to hold, \(f\) has to be differentiable everywhere (not just almost everywhere, see e.g., Stat 547C notes, section 6.8 for an example).