Chapter II: Background
2.3 Variational Inference
Algorithm 1 Variational Expectation Maximization (EM)
1: Input: model $p_\theta(\mathbf{x}, \mathbf{z})$, data examples $\mathbf{x}^{(1)}, \dots, \mathbf{x}^{(N)} \sim p_{\text{data}}(\mathbf{x})$
2: while $\theta$ not converged do
3:   for $\mathbf{x} = \mathbf{x}^{(1)}, \dots, \mathbf{x}^{(N)}$ do
4:     $q(\mathbf{z}|\mathbf{x}) \leftarrow \arg\max_q \mathcal{L}(\mathbf{x}; \theta, q)$  ▷ inference (E-step)
5:   end for
6:   $\theta \leftarrow \arg\max_\theta \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}(\mathbf{x}^{(i)}; \theta, q)$  ▷ learning (M-step)
7: end while
Rearranging terms in Eq. 2.20, we have
$$\log p_\theta(\mathbf{x}) = \mathcal{L}(\mathbf{x}; \theta, q) + D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x})). \qquad (2.23)$$

Because KL divergence is non-negative, we see that $\mathcal{L}(\mathbf{x}; \theta, q) \leq \log p_\theta(\mathbf{x})$, with equality when $q(\mathbf{z}|\mathbf{x}) = p_\theta(\mathbf{z}|\mathbf{x})$. As the LHS of Eq. 2.23 does not depend on $q(\mathbf{z}|\mathbf{x})$, maximizing $\mathcal{L}(\mathbf{x}; \theta, q)$ w.r.t. $q$ implicitly minimizes $D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x}))$ w.r.t. $q$. Together, these statements imply that maximizing $\mathcal{L}(\mathbf{x}; \theta, q)$ w.r.t. $q$ tightens the lower bound on $\log p_\theta(\mathbf{x})$. With this tightened lower bound, we can then maximize $\mathcal{L}(\mathbf{x}; \theta, q)$ w.r.t. $\theta$. This alternating optimization process is referred to as the variational expectation maximization (EM) algorithm (Dempster, Laird, and Rubin, 1977; Neal and G. E. Hinton, 1998) (Algorithm 1), consisting of approximate inference (E-step) and learning (M-step). While Algorithm 1 iterates over the entire data set between parameter updates, one can instead perform inference for a mini-batch of data examples and perform stochastic gradient ascent on $\theta$ (Hoffman et al., 2013). This version of the algorithm, referred to as stochastic variational inference (SVI), is used in practice to train deep latent variable models.
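As a concrete illustration of Algorithm 1, the following sketch runs variational EM on a toy linear-Gaussian model, $p_\theta(x, z) = \mathcal{N}(z; 0, 1)\,\mathcal{N}(x; \theta z, 1)$, where the E-step is exact (the posterior is Gaussian) and the M-step has a closed form. The specific model and parameter values are assumptions for illustration, not from the text:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: p_theta(x, z) = N(z; 0, 1) * N(x; theta*z, 1).
# The exact posterior is p_theta(z|x) = N(z; theta*x/(1+theta^2), 1/(1+theta^2)),
# so the E-step makes the ELBO tight.
theta_true = 2.0
N = 1000
z = rng.standard_normal(N)
x = theta_true * z + rng.standard_normal(N)

theta = 0.5  # initial model parameter
for _ in range(50):
    # E-step: q(z|x) <- p_theta(z|x) for every example
    prec = 1.0 + theta**2
    mu_q = theta * x / prec   # posterior means
    var_q = 1.0 / prec        # shared posterior variance
    # M-step: maximize the average E_q[log p_theta(x|z)] over the data,
    # giving theta = sum(x * E[z]) / sum(E[z^2]).
    theta = np.sum(x * mu_q) / np.sum(mu_q**2 + var_q)

print(theta)  # converges close to the true value 2.0
```

With an exact E-step, each iteration monotonically increases the data log-likelihood, as standard EM theory guarantees.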
Interpreting the ELBO
The primary motivation for variational inference is in providing a tractable lower bound for model training. In the process, however, we are left with an approximate posterior distribution, $q(\mathbf{z}|\mathbf{x})$. To gain insight into this distribution and the approximate inference procedure, we can rewrite the ELBO as

$$\mathcal{L}(\mathbf{x}; \theta, q) = \mathbb{E}_{\mathbf{z} \sim q(\mathbf{z}|\mathbf{x})} \left[ \log p_\theta(\mathbf{x}|\mathbf{z}) + \log p_\theta(\mathbf{z}) - \log q(\mathbf{z}|\mathbf{x}) \right]. \qquad (2.24)$$

Each of the terms in the objective guides the optimization of $q(\mathbf{z}|\mathbf{x})$. The first term quantifies the agreement with the data, i.e., reconstruction: sampling $\mathbf{z}$ from $q$, we want to assign high log-probability to the observation, $\mathbf{x}$. The second term quantifies the agreement with the prior prediction: sampling $\mathbf{z}$ from $q$, we want to have high
Figure 2.4: ELBO Computation Graphs. (a) Basic computation graph for variational inference. Outlined circles denote distributions. Smaller red circles denote terms in the ELBO objective. Arrows, again, denote conditional dependencies. This notation can be used to express (b) hierarchical and (c) sequential models with various model dependencies.
log-probability under the prior. The final term is the entropy of $q(\mathbf{z}|\mathbf{x})$, quantifying the spread or uncertainty of the distribution. Maximizing this term encourages $q(\mathbf{z}|\mathbf{x})$ to remain maximally uncertain, all else being equal (Jaynes, 1957). Often, the final two terms are combined, as in Eq. 2.22, with $D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}))$ interpreted as a form of "regularization," penalizing deviations of $q(\mathbf{z}|\mathbf{x})$ from $p_\theta(\mathbf{z})$.
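The equivalence between the final two terms of Eq. 2.24 and a negative KL divergence can be checked numerically. The sketch below (the scalar Gaussian $q$, standard Normal prior, and parameter values are illustrative assumptions) compares a Monte Carlo estimate of $\mathbb{E}_q[\log p_\theta(\mathbf{z}) - \log q(\mathbf{z}|\mathbf{x})]$ against the closed-form Gaussian KL:

```python
import numpy as np

rng = np.random.default_rng(1)

# Assumed setup: q(z|x) = N(mu, sigma^2), prior p(z) = N(0, 1), both scalar.
mu, sigma = 0.7, 0.5

# Monte Carlo estimate of the last two ELBO terms in Eq. 2.24.
z = mu + sigma * rng.standard_normal(100_000)
log_prior = -0.5 * np.log(2 * np.pi) - 0.5 * z**2  # log p(z) samples
log_q = -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((z - mu) / sigma) ** 2
neg_kl_mc = np.mean(log_prior - log_q)

# Closed form: -D_KL(N(mu, sigma^2) || N(0, 1)).
neg_kl_exact = -0.5 * (mu**2 + sigma**2 - 1.0 - np.log(sigma**2))

print(neg_kl_mc, neg_kl_exact)  # the two values agree closely
```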
This KL divergence term has been the focus of recent investigations. In particular, one can interpret the ELBO as a Lagrangian, with Lagrange multiplier $\beta = 1$:
$$\mathcal{L}(\mathbf{x}; \theta, q) = \mathbb{E}_{\mathbf{z} \sim q(\mathbf{z}|\mathbf{x})} \left[ \log p_\theta(\mathbf{x}|\mathbf{z}) \right] - \beta \, D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z})). \qquad (2.25)$$

Thus, the ELBO can be seen as expressing a constrained optimization problem, with the regularizing constraint on $D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}))$ mediated by $\beta$. In models with flexible, e.g., autoregressive, conditional likelihoods, the model may not utilize $\mathbf{z}$, in which case the KL term results in a local maximum at $q(\mathbf{z}|\mathbf{x}) = p_\theta(\mathbf{z})$. This is the so-called "dying units" problem (Bowman et al., 2016), in which the latent variables are effectively unused. A common approach is to anneal $\beta$ from $0 \rightarrow 1$ during the course of training (Bowman et al., 2016; Sønderby et al., 2016), though alternative schemes exist (Kingma, Salimans, et al., 2016; X. Chen et al., 2017). We can also dynamically adjust $\beta$ to maintain a particular KL or conditional likelihood constraint (Rezende and Viola, 2018). And by setting $\beta > 1$ we can over-regularize $q(\mathbf{z}|\mathbf{x})$, and with an independent prior, this can yield more disentangled latent variables (Higgins et al., 2017; Burgess et al., 2018). In these latter cases, where $\beta \neq 1$, $\mathcal{L}(\mathbf{x}; \theta, q)$ is no longer a valid lower bound on $\log p_\theta(\mathbf{x})$. Alemi et al., 2018 provide an information-theoretic framework to make sense of these ideas, identifying the connection between the two terms in the ELBO (Eq. 2.25) and the concepts of distortion and rate, respectively. In short, adjusting $\beta$ allows one to target particular rate-distortion trade-offs, which, in turn, result in bounds on the mutual information between $\mathbf{x}$ and $\mathbf{z}$. This perspective allows one to consider latent variable models that achieve varying compression rates in complex, natural data domains, such as images (Ballé, Laparra, and Simoncelli, 2017) and video (Lombardo et al., 2019).
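The annealing scheme described above can be sketched as follows; the linear warm-up schedule is one common choice, and the warm-up length and objective values here are placeholder assumptions:

```python
def beta_schedule(step, warmup_steps=10_000):
    """Linear KL annealing: beta increases 0 -> 1 over warmup_steps."""
    return min(1.0, step / warmup_steps)

def beta_elbo(log_likelihood, kl, beta):
    """Eq. 2.25: E_q[log p(x|z)] - beta * D_KL(q(z|x) || p(z))."""
    return log_likelihood - beta * kl

# Early in training the KL term is ignored; later it is fully weighted.
print(beta_elbo(-80.0, 5.0, beta_schedule(0)))       # -80.0
print(beta_elbo(-80.0, 5.0, beta_schedule(10_000)))  # -85.0
```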
As with autoregressive models (Figure 2.3), we can represent latent variable models and the ELBO objective as a computation graph, again breaking the dependency structure graphs from Figure 2.2 into separate distributions and terms in the objective.
In Figure 2.4, we illustrate examples of these graphs. Each variable contains a red circle, denoting a term in the ELBO objective. In comparison with the fully-observed autoregressive model, we now have an additional objective term for the latent variable, corresponding to the KL divergence. This graphical representation also allows us to visualize the variational objective for more complex hierarchical and sequential models (Figures 2.4b & 2.4c).
Inference Optimization
In deriving variational inference and variational EM, we described inference optimization, the E-step, as the process of maximizing $\mathcal{L}(\mathbf{x}; \theta, q)$ w.r.t. the distribution $q(\mathbf{z}|\mathbf{x})$. The fact that probability distributions are functions makes $\mathcal{L}(\mathbf{x}; \theta, q)$ a functional and makes inference optimization a variational calculus problem. One could solve this optimization problem using a variety of techniques, ranging from gradient-free, e.g., evolution strategies, to first-order gradient-based, e.g., black-box variational inference (Ranganath, Gerrish, and Blei, 2014), to second-order gradient-based, e.g., the Laplace approximation (Friston et al., 2007; Park, C. Kim, and G. Kim, 2019). Despite this range, first-order gradient-based techniques are most common in practice; when combined with amortization (see below), they can result in exceedingly efficient optimization. For this reason, we exclusively focus on first-order techniques in the following chapters, which we now describe in more detail.
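To make first-order inference optimization concrete, consider a toy conjugate model, $p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(x; z, 1)$, whose true posterior is $\mathcal{N}(x/2, 1/2)$. For a Gaussian $q$, the ELBO gradients w.r.t. the variational parameters are available in closed form, so gradient ascent can be run without stochastic estimation. A minimal sketch; the model, step size, and iteration count are assumptions for illustration:

```python
# Toy model: p(z) = N(0, 1), p(x|z) = N(x; z, 1); q(z|x) = N(mu, sigma^2).
# The true posterior is N(x/2, 1/2), so gradient ascent on the ELBO
# w.r.t. lambda = (mu, sigma) should converge there.
x = 3.0
mu, sigma = 0.0, 1.0
lr = 0.05
for _ in range(500):
    # Closed-form ELBO gradients for this conjugate Gaussian model:
    grad_mu = (x - mu) - mu            # reconstruction pull + prior pull
    grad_sigma = 1.0 / sigma - 2.0 * sigma  # entropy vs. precision terms
    mu += lr * grad_mu
    sigma += lr * grad_sigma

print(mu, sigma**2)  # approaches the true posterior: mean 1.5, variance 0.5
```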
With a parametric $q(\mathbf{z}|\mathbf{x})$, first-order inference optimization entails calculating the gradients of $\mathcal{L}(\mathbf{x}; \theta, q)$ w.r.t. the distribution parameters of $q$, which we refer to as $\lambda$. For instance, if $q(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}_q, \mathrm{diag}(\boldsymbol{\sigma}_q^2))$, i.e., $\lambda = [\boldsymbol{\mu}_q, \boldsymbol{\sigma}_q]$, then we must calculate $\nabla_{\boldsymbol{\mu}_q} \mathcal{L}$ and $\nabla_{\boldsymbol{\sigma}_q} \mathcal{L}$. We can consider the ELBO as an example of a stochastic
Figure 2.5: Stochastic Gradient Estimation. (a) Variational inference with parametric approximate posteriors requires optimizing $\mathcal{L}$ w.r.t. the distribution parameters, $\lambda$. This often requires estimating stochastic gradients (red dotted lines). (b) The score function estimator (Eq. 2.26) is applicable to any distribution, but suffers from high variance. Note: the solid red arrow denotes the per-sample objective, $f(\mathbf{x}, \mathbf{z}; \theta)$. (c) The pathwise derivative estimator (Eq. 2.27), in contrast, has lower variance, but is less widely applicable.
computation graph (Schulman et al., 2015), containing the stochastic sampling operation $\mathbf{z} \sim q(\mathbf{z}|\mathbf{x})$. Unfortunately, stochastic sampling is non-differentiable, preventing us from simply differentiating through the chain $\lambda \rightarrow \mathbf{z} \rightarrow \mathcal{L}$. Schulman et al., 2015 highlight two stochastic gradient estimators to tackle this issue: the score function estimator (Glynn, 1990; Williams, 1992; Fu, 2006) and the pathwise derivative estimator (Kingma and Welling, 2014; Rezende, Mohamed, and Wierstra, 2014; Titsias and Lázaro-Gredilla, 2014), each of which allows us to calculate stochastic estimates of $\nabla_\lambda \mathcal{L}$.
The score function estimator, sometimes referred to as the REINFORCE estimator (Williams, 1992), is generally applicable to any parametric distribution. With $q_\lambda(\mathbf{z}|\mathbf{x})$ denoting the dependence of $q$ on $\lambda$, as well as $\mathcal{L}(\mathbf{x}; \theta, \lambda) = \mathbb{E}_{\mathbf{z} \sim q_\lambda(\mathbf{z}|\mathbf{x})}[f(\mathbf{x}, \mathbf{z}; \theta)]$, the score function estimator allows us to express $\nabla_\lambda \mathcal{L}$ as

$$\nabla_\lambda \mathbb{E}_{\mathbf{z} \sim q_\lambda(\mathbf{z}|\mathbf{x})} \left[ f(\mathbf{x}, \mathbf{z}; \theta) \right] = \mathbb{E}_{\mathbf{z} \sim q_\lambda(\mathbf{z}|\mathbf{x})} \left[ \nabla_\lambda \log q_\lambda(\mathbf{z}|\mathbf{x}) \, f(\mathbf{x}, \mathbf{z}; \theta) \right]. \qquad (2.26)$$

While the score function estimator is unbiased and can be successfully utilized for variational inference (Gregor et al., 2014; Ranganath, Gerrish, and Blei, 2014; Mnih and Gregor, 2014; Mnih and Rezende, 2016), in practice, it requires considerable tuning and computation, owing to its empirically high variance. Alternatively, the pathwise derivative estimator, sometimes referred to as the reparameterization estimator (Kingma and Welling, 2014), when applicable, allows us to obtain unbiased
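A minimal numerical check of Eq. 2.26, where the distribution and per-sample objective are illustrative assumptions: for $q_\lambda = \mathcal{N}(\mu, 1)$ and $f(z) = z^2$, we have $\mathbb{E}[z^2] = \mu^2 + 1$, so the true gradient w.r.t. $\mu$ is $2\mu$:

```python
import numpy as np

rng = np.random.default_rng(2)

# Score function estimator for grad_mu E_{z ~ N(mu, 1)}[f(z)], f(z) = z^2.
# Analytically, E[z^2] = mu^2 + 1, so the true gradient is 2 * mu.
mu = 1.5
z = mu + rng.standard_normal(200_000)
score = z - mu                    # grad_mu log N(z; mu, 1)
grad_est = np.mean(score * z**2)  # Eq. 2.26, Monte Carlo average

print(grad_est)  # noisy estimate of the true gradient 2 * mu = 3.0
```

Even with 200,000 samples, the estimate is noticeably noisy, reflecting the estimator's high per-sample variance.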
gradient estimates with considerably lower variance. This is accomplished by reparameterizing $\mathbf{z}$ in terms of an auxiliary random variable, enabling differentiation through stochastic sampling. The most common example is reparameterizing $\mathbf{z} \sim \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}_q, \mathrm{diag}(\boldsymbol{\sigma}_q^2))$ as $\mathbf{z} = \boldsymbol{\mu}_q + \boldsymbol{\sigma}_q \odot \boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon} \sim \mathcal{N}(\boldsymbol{\epsilon}; \mathbf{0}, \mathbf{I})$ and $\odot$ denotes element-wise multiplication. More generally, for some deterministic function $g$, if we can express $\mathbf{z} = g(\lambda, \boldsymbol{\epsilon})$ with $\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})$, then the pathwise derivative estimator allows us to express $\nabla_\lambda \mathcal{L}$ as

$$\nabla_\lambda \mathbb{E}_{\mathbf{z} \sim q_\lambda(\mathbf{z}|\mathbf{x})} \left[ f(\mathbf{x}, \mathbf{z}; \theta) \right] = \mathbb{E}_{\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})} \left[ \nabla_\lambda f(\mathbf{x}, g(\lambda, \boldsymbol{\epsilon}); \theta) \right]. \qquad (2.27)$$

With these stochastic gradient estimators in hand, we can perform stochastic gradient-based optimization of $\mathcal{L}(\mathbf{x}; \theta, \lambda)$ w.r.t. $\lambda$ (Ranganath, Gerrish, and Blei, 2014):
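Repeating the same toy computation with Eq. 2.27 illustrates the variance gap between the two estimators (again, the distribution $\mathcal{N}(\mu, 1)$ and objective $f(z) = z^2$ are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(3)

# Pathwise (reparameterization) estimator for the same toy objective:
# z ~ N(mu, 1) reparameterized as z = mu + eps, eps ~ N(0, 1), f(z) = z^2.
mu = 1.5
eps = rng.standard_normal(200_000)
z = mu + eps

pathwise = 2.0 * z          # Eq. 2.27: per-sample grad_mu f(g(mu, eps))
score = (z - mu) * z**2     # score function samples (Eq. 2.26), for comparison

print(pathwise.mean())                # estimate of the true gradient 2 * mu = 3.0
print(pathwise.var(), score.var())    # pathwise variance is far lower
```

Both estimators are unbiased here, but the pathwise samples concentrate much more tightly around the true gradient.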
estimating $\nabla_\lambda \mathcal{L}$ by sampling $\mathbf{z}$ (or $\boldsymbol{\epsilon}$) and updating $\lambda$ using stochastic gradient ascent. Unfortunately, a naïve implementation of this inference optimization (E-step) procedure scales poorly to large models and data sets. This is because many gradient steps must be performed per example, just to provide one estimate of $\mathcal{L}(\mathbf{x}; \theta, q)$ for M-step optimization. More sophisticated approaches cache previous estimates of $\lambda$ for each example to initialize future inference optimization. However, beyond the additional memory overhead, this approach still requires tuning the inference optimizer's learning rate and is not applicable to online learning settings.
In order to apply variational inference more broadly, we require a simple technique for dramatically improving the efficiency of inference optimization. Amortization (Gershman and Goodman, 2014), a form of meta-optimization, offers a viable solution.
Amortized Variational Inference
Amortization, in a generic sense, refers to spreading out costs. In amortized inference, these "costs" are the computational costs of performing inference optimization.
Thus, rather than separately optimizing $\lambda$ for each data example, we amortize this optimization cost using a learned optimizer, i.e., an inference model. By using this meta-optimization procedure, we can perform inference optimization far more efficiently for each example, with a negligible cost for learning the inference model.
The concept of inference models is deeply intertwined with deep latent variable models, popularized by the Helmholtz Machine (Dayan et al., 1995), which was formulated as an autoencoder (Ballard, 1987). Formally, in such setups, the inference model is a direct mapping from $\mathbf{x}$ to $\lambda$:

$$\lambda \leftarrow f_\phi(\mathbf{x}), \qquad (2.28)$$
Figure 2.6: Variational Autoencoder (VAE). VAEs combine direct amortization (Eq. 2.28) and the pathwise derivative estimator (Eq. 2.27, top) with Gaussian approximate posteriors to train deep latent variable models. In the model diagram (center), the amortized inference model (left) acts as an encoder, with the conditional likelihood (right) acting as a decoder. Each is parameterized by a deep network.
where $f_\phi$ is a model (deep network) with parameters $\phi$. Conventionally, we denote the approximate posterior as $q_\phi(\mathbf{z}|\mathbf{x})$ to denote the parameterization by $\phi$. Now, rather than optimizing $\lambda$ using gradient-based techniques, we periodically update $\phi$ using $\nabla_\phi \mathcal{L} = \frac{\partial \mathcal{L}}{\partial \lambda} \frac{\partial \lambda}{\partial \phi}$, thereby letting $f_\phi$ learn to optimize $\lambda$. This procedure is incredibly simple, as we only need to tune the learning rate for $\phi$, and efficient, as we have an estimate of $\lambda$ after only one forward pass through $f_\phi$. Amortization is also widely applicable: if we can estimate $\nabla_\lambda \mathcal{L}$ using stochastic gradient estimation (see above), we can continue differentiating through the chain $\phi \rightarrow \lambda \rightarrow \mathbf{z} \rightarrow \mathcal{L}$.
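This meta-optimization can be sketched on the toy conjugate model $p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(x; z, 1)$, using a hypothetical one-parameter inference model $f_\phi(x) = w x$ for the posterior mean (the variance is held fixed at its optimal value $1/2$). The optimal amortized mapping is $\mu = x/2$, so $w$ should approach $0.5$; all names and values here are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(4)

# Amortized inference for the toy model p(z) = N(0, 1), p(x|z) = N(x; z, 1):
# a one-parameter inference model f_phi(x) = w * x outputs the posterior mean.
x = rng.standard_normal(256) * 2.0  # a batch of observations
w = 0.0
lr = 0.05
for _ in range(200):
    mu = w * x                      # forward pass: lambda = f_phi(x)
    dL_dmu = x - 2.0 * mu           # closed-form ELBO gradient w.r.t. mu
    w += lr * np.mean(dL_dmu * x)   # chain rule: dL/dw = (dL/dmu) * (dmu/dw)

print(w)  # converges close to 0.5
```

After training, a single forward pass `w * x_new` yields the variational parameters for a new example, with no per-example optimization loop.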
When direct amortization is combined with the pathwise derivative estimator in deep latent variable models, the resulting setup is referred to as a variational autoencoder (VAE) (Kingma and Welling, 2014; Rezende, Mohamed, and Wierstra, 2014). In this autoencoder interpretation, $q_\phi(\mathbf{z}|\mathbf{x})$ is an encoder, $\mathbf{z}$ is the latent code, and $p_\theta(\mathbf{x}|\mathbf{z})$ is a decoder. A computation graph is shown in Figure 2.6. This direct encoding scheme seems intuitively obvious: in the same way that $p_\theta(\mathbf{x}|\mathbf{z})$ directly maps $\mathbf{z}$ to a distribution over $\mathbf{x}$, $q_\phi(\mathbf{z}|\mathbf{x})$ directly maps $\mathbf{x}$ to a distribution over $\mathbf{z}$. Indeed, with perfect knowledge of $p_\theta(\mathbf{x}, \mathbf{z})$, $f_\phi$ could act as a lookup table, precisely mapping each $\mathbf{x}$ to the corresponding optimal $\lambda$. However, as we discuss in Chapter 3, there are a number of subtle issues that can arise with direct amortization, forming one of