Stochastic gradient estimation
Outline
Topics
- Stochastic gradient estimation problem
- Push-out estimator (also known as “reparameterization trick”)
- Subsampling
Rationale
To optimize our variational objective \(L\) using SGD, we need to construct an unbiased estimator of the gradient of \(L\).
We review one particularly effective method for doing so, which comes from the operations research literature: the push-out estimator (Rubinstein, 1992). In the machine learning literature it is called the “reparameterization trick”.
Variational inference objective
Recall that our VI objective function is \[\begin{align*} L(\phi) &= \int q_\phi(x) \left[ \log q_\phi(x) - \log \gamma(x) \right] \mathrm{d}x \\ &= \mathbb{E}_\phi\left[ \log q_\phi(X) - \log \gamma(X) \right], \end{align*}\] where we write a subscript \(\phi\) on \(\mathbb{E}_\phi[\cdot]\) to emphasize that the expectation is with respect to a distribution \(q_\phi\) that depends on \(\phi\).
Generic gradient estimation setup
To keep the notation clean, we abstract the problem as follows:
Definition: the stochastic gradient estimation problem consists of finding an unbiased estimator for the gradient of \(\mathbb{E}_\phi[h(\phi, X)],\) where \(X \sim q_\phi\).
Example: for VI, \(h(\phi, x) = \log q_\phi(x) - \log \gamma(x)\).
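As a concrete instance of this setup, here is a minimal sketch that estimates \(L(\phi) = \mathbb{E}_\phi[h(\phi, X)]\) by simple Monte Carlo. It assumes a toy problem that is not taken from the notes: \(q_\phi = \mathcal{N}(\mu, \sigma^2)\) and the unnormalized target \(\gamma(x) = \exp(-x^2/2)\), i.e., a standard normal with normalizing constant \(Z = \sqrt{2\pi}\). In that case \(L(\phi) = \mathrm{KL}(q_\phi \,\|\, \mathcal{N}(0, 1)) - \log Z\), which gives a closed form to compare against. All function names are hypothetical.

```python
import math
import random

# Assumed toy problem (illustrative, not from the notes):
#   q_phi = N(mu, sigma^2), gamma(x) = exp(-x^2 / 2), so Z = sqrt(2 * pi).

def log_q(mu, sigma, x):
    """Log density of N(mu, sigma^2) at x."""
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * ((x - mu) / sigma) ** 2

def log_gamma(x):
    """Log of the unnormalized target gamma(x) = exp(-x^2 / 2)."""
    return -0.5 * x * x

def h(mu, sigma, x):
    """h(phi, x) = log q_phi(x) - log gamma(x), with phi = (mu, sigma)."""
    return log_q(mu, sigma, x) - log_gamma(x)

def estimate_L(mu, sigma, M, rng):
    """Simple Monte Carlo estimate of L(phi) = E_phi[h(phi, X)] with X ~ q_phi."""
    return sum(h(mu, sigma, rng.gauss(mu, sigma)) for _ in range(M)) / M

def exact_L(mu, sigma):
    """Closed form for the toy problem: KL(N(mu, sigma^2) || N(0, 1)) - log Z."""
    kl = 0.5 * (sigma ** 2 + mu ** 2 - 1.0) - math.log(sigma)
    return kl - 0.5 * math.log(2 * math.pi)

rng = random.Random(0)
print(estimate_L(1.5, 0.5, 100_000, rng), exact_L(1.5, 0.5))
```

Note that \(L\) can be negative here: it equals the KL divergence minus \(\log Z\), and only the KL part is guaranteed to be non-negative.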
Difficulty
Find the most serious error in the buggy argument below:
\[\begin{align*} \nabla \mathbb{E}_\phi[h(\phi, X)] &= \mathbb{E}_\phi[\nabla h(\phi, X)] \;\;\text{(interchange of $\nabla$ and $\mathbb{E}$)} \\ &\approx \underbrace{ \frac{1}{M} \sum_{m=1}^M \nabla h(\phi, X^{(m)})}_{\text{broken stochastic gradient estimator}} \;\;\text{(simple Monte Carlo)}. \end{align*}\]
- Product rule on \(q_\phi(x) \log q_\phi(x)\) leads to a missing term.
- One can never interchange \(\nabla\) and \(\mathbb{E}\).
- Simple Monte Carlo cannot be applied because \(q\) is intractable.
- Simple Monte Carlo cannot be applied because of some other reason.
- None of the above.
The correct answer is 1. Recall the product rule: \((u v)' = u' v + u v'\).
Regarding 2: while the interchange needs to be justified mathematically, it does hold under certain conditions, for example when the function \(h\) is bounded and differentiable everywhere; see 547C lecture notes, Section 6.8 for more details.
Here are the details of where the missing term comes from (assuming the interchange is justified):
\[\begin{align*} \nabla \mathbb{E}_\phi[ h(\phi, X)] &= \nabla \int q_\phi(x) h(\phi, x) \mathrm{d}x \\ &= \int \nabla[q_\phi(x) h(\phi, x)] \mathrm{d}x \\ &= \underbrace{\int q_\phi(x) \nabla h(\phi, x) \mathrm{d}x}_{\mathbb{E}_\phi[\nabla h(\phi, X)]} + \underbrace{\int h(\phi, x) \nabla q_\phi(x) \mathrm{d}x.}_\text{term missing in the above "broken estimator"} \end{align*}\]
Note that the missing term is “annoying” because it no longer has the \(q_\phi(x)\) factor, which makes it less straightforward to write as an expectation and hence to approximate using simple Monte Carlo.
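To see concretely that the broken estimator targets the wrong quantity, here is a small sketch under an assumed toy setup (not from the notes): \(q_\phi = \mathcal{N}(\mu, \sigma^2)\) and \(\gamma(x) = \exp(-x^2/2)\), for which the true gradient is \(\partial L / \partial \mu = \mu\). Since \(\log \gamma\) does not depend on \(\phi\), the broken estimator only averages the score \(\partial_\mu \log q_\phi(X) = (X - \mu)/\sigma^2\), whose expectation is zero, so it stays near 0 whatever the value of \(\mu\). The function names are hypothetical.

```python
import random

# Assumed toy setup: q_phi = N(mu, sigma^2), gamma(x) = exp(-x^2 / 2),
# h(phi, x) = log q_phi(x) - log gamma(x), and the true dL/dmu equals mu.

def dh_dmu(mu, sigma, x):
    """Partial derivative of h(phi, x) in mu with x held fixed.
    Only log q_phi depends on mu, which gives the score (x - mu) / sigma^2."""
    return (x - mu) / sigma ** 2

def broken_estimator(mu, sigma, M, rng):
    """Average of grad h at samples X ~ q_phi; drops the grad q_phi term."""
    return sum(dh_dmu(mu, sigma, rng.gauss(mu, sigma)) for _ in range(M)) / M

rng = random.Random(1)
mu, sigma = 1.5, 0.5
print("broken estimate:", broken_estimator(mu, sigma, 100_000, rng))
print("true dL/dmu:    ", mu)
```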
Solution: reparameterization
Idea: move the parts of \(q\) that depend on \(\phi\) into \(h\) via reparameterization.
Example:
- suppose \(\{q_\phi\}\) is a normal family, so \(\phi = (\mu, \sigma^2)\).
- Note:
- if \(S \sim \mathcal{N}(0, 1)\) is standard normal,
- then \(\sigma S + \mu \sim \mathcal{N}(\mu, \sigma^2)\)
- hence: \[\mathbb{E}_\phi[h(\phi, X)] = \mathbb{E}[h(\phi, \sigma S + \mu)].\]
- Notice that on the right-hand side, the distribution with respect to which we take the expectation no longer depends on \(\phi\)!
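The identity above can be checked numerically with a short sketch. It assumes an illustrative choice \(h(\phi, x) = x^2\) (not from the notes), for which \(\mathbb{E}[X^2] = \mu^2 + \sigma^2\) when \(X \sim \mathcal{N}(\mu, \sigma^2)\). Both estimators target the same expectation; the second draws only from \(\mathcal{N}(0, 1)\), whose law does not involve \(\phi\). Function names are hypothetical.

```python
import random

# Illustrative choice: h(phi, x) = x^2, so the exact expectation is mu^2 + sigma^2.

def h(x):
    return x * x

def mc_direct(mu, sigma, M, rng):
    """Simple Monte Carlo with X ~ N(mu, sigma^2) drawn directly (law depends on phi)."""
    return sum(h(rng.gauss(mu, sigma)) for _ in range(M)) / M

def mc_reparam(mu, sigma, M, rng):
    """Simple Monte Carlo with S ~ N(0, 1) (law free of phi) and X = sigma * S + mu."""
    return sum(h(sigma * rng.gauss(0.0, 1.0) + mu) for _ in range(M)) / M

rng = random.Random(0)
mu, sigma = 1.5, 0.5
print(mc_direct(mu, sigma, 100_000, rng), mc_reparam(mu, sigma, 100_000, rng), mu ** 2 + sigma ** 2)
```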
General method: (“push-out estimator” or “reparameterization trick”)
- if for all \(\phi\), \(X_\phi \sim q_\phi\)…
- …you can write \(X_\phi = f(S, \phi)\)
- for some random variable \(S\)
- and a function \(f(s, \phi)\) (e.g., \(f(s, \phi) = \sigma s + \mu\) in the normal case),¹
- then:
\[\begin{align*} \nabla \mathbb{E}_\phi[h(\phi, X)] &= \nabla \mathbb{E}[h(\phi, f(S, \phi))] \\ &= \mathbb{E}[\nabla\{h(\phi, f(S, \phi))\}] \;\;\text{(assuming Leibniz integral rule applies)} \\ &\approx \frac{1}{M} \sum_{m=1}^M \nabla\{h(\phi, f(S^{(m)}, \phi))\}. \end{align*}\]
Finally, reverse-mode automatic differentiation (autodiff) is typically used to compute \(\nabla\{h(\phi, f(S^{(m)}, \phi))\}\).
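A minimal sketch of the push-out estimator, followed by plain SGD, for an assumed toy problem (not from the notes): \(q_\phi = \mathcal{N}(\mu, \sigma^2)\) and \(\gamma(x) = \exp(-x^2/2)\). To stay within the Python standard library, the gradient of \(h(\phi, f(s, \phi))\) is derived by hand instead of with reverse-mode autodiff, and \(\phi\) is taken to be \((\mu, \sigma)\) rather than \((\mu, \sigma^2)\) (by the chain rule, the \(\sigma^2\)-derivative is the \(\sigma\)-derivative divided by \(2\sigma\)). With \(x = \sigma s + \mu\), one has \(\log q_\phi(x) = -\log\sigma - \tfrac12\log(2\pi) - s^2/2\) and \(-\log\gamma(x) = x^2/2\), so the per-sample gradient is \((x,\; -1/\sigma + x s)\), whose expectation \((\mu, \sigma - 1/\sigma)\) matches \(\nabla L\). The learning rate, \(M\), step count and the positivity guard on \(\sigma\) are illustrative choices, and the function names are hypothetical.

```python
import math
import random

# Assumed toy problem (not from the notes):
#   q_phi = N(mu, sigma^2) with phi = (mu, sigma), gamma(x) = exp(-x^2 / 2).
# Push-out: X = f(S, phi) = sigma * S + mu with S ~ N(0, 1).

def reparam_h(mu, sigma, s):
    """h(phi, f(s, phi)) = log q_phi(x) - log gamma(x) at x = sigma * s + mu."""
    x = sigma * s + mu
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * s * s + 0.5 * x * x

def grad_h_reparam(mu, sigma, s):
    """Hand-derived gradient of reparam_h in (mu, sigma) for a fixed draw s."""
    x = sigma * s + mu
    return x, -1.0 / sigma + x * s

def push_out_gradient(mu, sigma, M, rng):
    """Average of M per-sample gradients; unbiased for grad L(phi)."""
    g_mu = g_sigma = 0.0
    for _ in range(M):
        d_mu, d_sigma = grad_h_reparam(mu, sigma, rng.gauss(0.0, 1.0))
        g_mu += d_mu
        g_sigma += d_sigma
    return g_mu / M, g_sigma / M

def sgd(mu, sigma, steps, lr, M, rng):
    """Plain SGD on L driven by the push-out gradient estimator."""
    for _ in range(steps):
        g_mu, g_sigma = push_out_gradient(mu, sigma, M, rng)
        mu -= lr * g_mu
        sigma = max(sigma - lr * g_sigma, 1e-3)  # crude guard keeping sigma positive
    return mu, sigma

rng = random.Random(2)
# Expected value of this estimate at (1.5, 0.5): (mu, sigma - 1/sigma) = (1.5, -1.5).
print(push_out_gradient(1.5, 0.5, 100_000, rng))
print(sgd(1.5, 0.5, steps=2000, lr=0.01, M=10, rng=rng))
```

Because the target in this toy problem is \(\mathcal{N}(0, 1)\), the SGD iterates should settle near \(\mu \approx 0\), \(\sigma \approx 1\), with some residual jitter from the stochastic gradients.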
Subsampling
- Often (but not always), the function \(h\) can be written as: \[h(\phi, x) = \sum_{i=1}^N h_i(\phi, x),\]
- example: when the data are i.i.d. (recall we are in log space, \(h(\phi, x) = \log q_\phi(x) - \log \gamma(x)\), so the likelihood contributes a sum of per-data-point log-likelihoods)
- \(N\) is then the number of data points
- When \(N\) is large, sub-sampling trades off:
- computationally cheaper gradient…
- …at the cost of more noise (variance) and hence more SGD iterations \(t\) needed.
- Idea: sub-sampling consists of
- sampling one term (data point) \(I \sim {\mathrm{Unif}}\{1, 2, \dots, N\}\)
- computing an unbiased estimate \(\hat G_I\) of the gradient of the selected term
- debiasing by returning \(N \hat G_I\)
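The three steps above can be sketched as follows, assuming the per-term estimates \(\hat G_i\) are already available as plain numbers (a hypothetical list `G`); in practice each would itself be a stochastic gradient of \(h_i\), for instance a push-out estimate. The function names are hypothetical. The exact expectation over \(I\) recovers the full sum, and a long run of draws averages to it as well, while individual draws can land far from it, which is the extra variance mentioned above.

```python
import math
import random

def subsample_gradient(G_hat, rng):
    """Return N * G_hat[I] with I drawn uniformly from the N terms.
    G_hat holds per-term gradient estimates (fixed numbers in this sketch)."""
    N = len(G_hat)
    I = rng.randrange(N)  # sample one term (data point) uniformly
    return N * G_hat[I]   # debias by multiplying by N

def exact_expectation(G_hat):
    """E[N * G_hat[I]] obtained by averaging over all N equally likely values of I."""
    N = len(G_hat)
    return sum(N * g for g in G_hat) / N

G = [0.3, -1.2, 2.5, 0.7, -0.1]  # hypothetical per-term estimates
print(math.isclose(exact_expectation(G), sum(G)))
rng = random.Random(4)
draws = [subsample_gradient(G, rng) for _ in range(200_000)]
print(sum(draws) / len(draws), sum(G))
```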
Property: the subsampling estimator \(N \hat G_I\) is unbiased provided that, for each \(i\), \(\hat G_i\) is an unbiased estimator of the gradient of the \(i\)-th term, \(\nabla \mathbb{E}_\phi[h_i(\phi, X)]\).
Proof:
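\[\begin{align*} \mathbb{E}[N \hat G_I] &= N \mathbb{E}[ \mathbb{E}[\hat G_I|I]] \;\;\text{(law of total expectation)} \\ &= N \sum_{i=1}^N \frac{1}{N} \mathbb{E}[\hat G_i] \;\;\text{($I$ drawn independently of the $\hat G_i$)} \\ &= \sum_{i=1}^N \mathbb{E}[\hat G_i] \\ &= \sum_{i=1}^N \nabla \mathbb{E}_\phi[h_i(\phi, X)] \;\;\text{(each $\hat G_i$ unbiased)} \\ &= \nabla \mathbb{E}_\phi[h(\phi, X)] \;\;\text{(linearity)}. \end{align*}\]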
Mini-batching: this idea can be extended to picking a small subset (mini-batch) of \(B\) points and rescaling the sum of their estimates by \(N/B\) instead of \(N\); \(B\) is typically the largest batch that fits in GPU memory.
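One common way to realize mini-batching, sketched here as an assumption rather than as the notes' prescription: draw a batch of \(B\) distinct indices uniformly without replacement and return \(\tfrac{N}{B}\sum_{i \in \text{batch}} \hat G_i\). Each index appears in a fraction \(B/N\) of all equally likely batches, so the rescaling keeps the estimator unbiased. The exact check below enumerates every batch for a small hypothetical \(N = 6\); the function names and per-term values are hypothetical.

```python
import itertools
import math
import random

def minibatch_gradient(G_hat, B, rng):
    """(N / B) * sum of G_hat[i] over B distinct indices drawn without replacement."""
    N = len(G_hat)
    batch = rng.sample(range(N), B)
    return (N / B) * sum(G_hat[i] for i in batch)

def exact_minibatch_expectation(G_hat, B):
    """Average of the estimator over all C(N, B) equally likely batches."""
    N = len(G_hat)
    batches = list(itertools.combinations(range(N), B))
    return sum((N / B) * sum(G_hat[i] for i in b) for b in batches) / len(batches)

G = [0.3, -1.2, 2.5, 0.7, -0.1, 1.1]  # hypothetical per-term estimates, N = 6
for B in range(1, len(G) + 1):
    print(B, math.isclose(exact_minibatch_expectation(G, B), sum(G)))
print(minibatch_gradient(G, 3, random.Random(5)))
```

With \(B = 1\) this reduces to the single-term subsampling estimator, and with \(B = N\) it returns the full sum with no sampling noise.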
References
See Mohamed et al. (2020), Monte Carlo Gradient Estimation in Machine Learning, JMLR.
Footnotes
¹ Typically, for the interchange of gradient and expectation to hold, \(f\) has to be differentiable everywhere, not just almost everywhere (see, e.g., Stat 547C notes, Section 6.8 for an example).