2.3 Variational Inference

Algorithm 1 Variational Expectation Maximization (EM)

1: Input: model p_θ(x, z), data examples x^(1), ..., x^(N) ∼ p_data(x)
2: while θ not converged do
3:   for x = x^(1), ..., x^(N) do
4:     q(z|x) ← arg max_q L(x; q, θ)    ⊲ inference (E-step)
5:   end for
6:   θ ← arg max_θ (1/N) Σ_{i=1}^N L(x^(i); q, θ)    ⊲ learning (M-step)
7: end while

Rearranging terms in Eq. 2.20, we have

logπ‘πœƒ(x) =L (x;π‘ž, πœƒ) +𝐷KL(π‘ž(z|x) ||π‘πœƒ(z|x)). (2.23) Because KL divergence is non-negative, we see that L (x;π‘ž, πœƒ) ≀ logπ‘πœƒ(x), with equality when π‘ž(z|x) = π‘πœƒ(z|x). As the LHS of Eq. 2.23 does not depend on π‘ž(z|x), maximizing L (x;π‘ž, πœƒ) w.r.t.π‘žimplicitly minimizes 𝐷KL(π‘ž(z|x) ||π‘πœƒ(z|x)) w.r.t.π‘ž. Together, these statements imply that maximizingL (x;π‘ž, πœƒ)w.r.t.π‘žtightens the lower bound on logπ‘πœƒ(x). With this tightened lower bound, we can then maximize L (x;π‘ž, πœƒ) w.r.t.πœƒ. This alternating optimization process is referred to as the variational expectation maximization (EM) algorithm (Dempster, Laird, and Rubin, 1977; Neal and G. E. Hinton, 1998) (Algorithm 1), consisting of approximate inference (E-step) and learning (M-step). While Algorithm 1 iterates over the entire data set between parameter updates, one can instead perform inference for a mini-batch of data examples and perform stochastic gradient ascent onπœƒ(Hoffman et al., 2013). This version of the algorithm, referred to asstochastic variational inference (SVI), is used in practice to train deep latent variable models.

Interpreting the ELBO

The primary motivation for variational inference is to provide a tractable lower bound for model training. In the process, however, we are left with an approximate posterior distribution, q(z|x). To gain insight into this distribution and the approximate inference procedure, we can rewrite the ELBO as

L (x;π‘ž, πœƒ) =EzβˆΌπ‘ž(z|x) [logπ‘πœƒ(x|z) +logπ‘πœƒ(z) βˆ’logπ‘ž(z|x)]. (2.24) Each of the terms in the objective guides the optimization ofπ‘ž(z|x). The first term quantifies the agreement with the data, i.e., reconstruction: samplingzfromπ‘ž, we want to assign high log-probability to the observation,x. The second term quantifies the agreement with the prior prediction: samplingzfromπ‘ž, we want to have high

7 pβœ“(z)

<latexit sha1_base64="IEFcZ6TdJDNiYnM2MXVVhnD1rxA=">AAACjHicdVHLSiNBFK20j3HiKzrMyk1hEHQTulUYQQRxQFwqGBWSpqmu3E6KVHUVVbfF2ORj3OoXzd9MdczCGL1w4XDu63BuaqRwGIb/asHC4tLyj5Wf9dW19Y3Nxtb2ndOF5dDmWmr7kDIHUuTQRoESHowFplIJ9+nwb1W/fwTrhM5vcWQgVqyfi0xwhp5KGr9N0sUBIKP7XcVwkGbl8/ggaTTDVjgJOg+iKWiSaVwnW7Wk29O8UJAjl8y5ThQajEtmUXAJ43q3cGAYH7I+dDzMmQIXlxP9Y7rnmR7NtPWZI52wHydKppwbqdR3Vhrd51pFflXrFJidxKXITYGQ8/dDWSEpalqZQXvCAkc58oBxK7xWygfMMo7esplNt1FcVuKqNTPnUzWu79GPTCXDoHqa7WOyr/2FgfqGFnx+wDgovKu65w30L4k+P2Ae3B22oqPW4c1x8/xi+pwVskN2yT6JyB9yTq7INWkTTkryQl7JW7ARHAenwdl7a1CbzvwiMxFc/gdUOMp5</latexit>

q(z|x)

<latexit sha1_base64="GKYv3YBzSW+vcCnXm4SsDmmxo+c=">AAACknicdVHLTgIxFC3jG1+g7tg0EhLckBk00bhC3bhwoQkgCUxIp3SgoZ2ObceA4yz8Grf6Of6NHcSEh96kyem59/aenuuFjCpt218Za2V1bX1jcyu7vbO7t5/LHzSViCQmDSyYkC0PKcJoQBqaakZaoSSIe4w8esObNP/4TKSiIqjrcUhcjvoB9SlG2lDdXOEJljsc6YHnxy8JfIW/l1Fy0s0V7Yo9CbgMnCkogmncd/OZbqcncMRJoDFDSrUdO9RujKSmmJEk24kUCREeoj5pGxggTpQbT36RwJJhetAX0pxAwwk72xEjrtSYe6Yy1agWcyn5V64daf/CjWkQRpoE+GeQHzGoBUwtgT0qCdZsbADCkhqtEA+QRFgb4+ZeqjtunIpLn5kb7/EkW4KzTCoj1Hw0X4dYX5gJA/4PTfFyQ6hIZFwVPWOgWYmzuIBl0KxWnNNK9eGsWLueLmcTFMAxKAMHnIMauAX3oAEweAPv4AN8WkfWpXVl3fyUWplpzyGYC+vuG+HlzN8=</latexit>

pβœ“(x|z)

<latexit sha1_base64="60DGexcpR/Fv1prZWkg4K8IbvkM=">AAACmnicdVFNTxsxEHW2LdC0fLXH9mARIdFLtBuQ4BjBhaoXkBJASlaR15lNLOy1Zc+ihGV/AL+GK/0p/Tf1hiARAiNZen7zxjN+kxgpHIbhv1rw4eOnldW1z/UvX9c3Nre2v104nVsOXa6ltlcJcyBFBl0UKOHKWGAqkXCZXJ9U+csbsE7orINTA7Fio0ykgjP01GCrQc2gj2NARvf6iuE4SYtJSe/o8+W2/OVVYTOcBV0G0Rw0yDzOBtu1QX+oea4gQy6Zc70oNBgXzKLgEsp6P3dgGL9mI+h5mDEFLi5mvynprmeGNNXWnwzpjH1ZUTDl3FQlXlnN6F7nKvKtXC/H9CguRGZyhIw/NUpzSVHTyho6FBY4yqkHjFvhZ6V8zCzj6A1ceKkTxUU1XPXMQvtElfVd+pKpxjCoJos6Jkfadxird2jBlwuMg9y7qofeQL+S6PUClsFFqxntN1vnB4328Xw5a+QH2SF7JCKHpE1OyRnpEk7uyQN5JH+Dn8Fx8Dv48yQNavOa72Qhgs5/fj/QHw==</latexit>

Eq[logpβœ“(x|z)]

<latexit sha1_base64="ThBf6Fct1myVHSay2XFe4My8dic=">AAACunicdVFLbxMxEHaWVwmvFI5cLKJK5RLtFiQ4cKiokDgWqWkrZVcr25ndtWKvjT2LGpb9QfwarvTf4E2C1DQwkuXP33zz8Ay3SnqM4+tBdOfuvfsP9h4OHz1+8vTZaP/5uTeNEzAVRhl3yZkHJWuYokQFl9YB01zBBV+c9P6Lb+C8NPUZLi1kmpW1LKRgGKh8dJJqhhXn7acu/0pTBQXOwmVKavMUK0BGD9eSor3q6A/69/G9e01TJ8sKs3w0jifxyuguSDZgTDZ2mu8P8nRuRKOhRqGY97Mktpi1zKEUCrph2niwTCxYCbMAa6bBZ+3qtx09CMycFsaFUyNdsTcjWqa9X2oelH2v/ravJ//lmzVYvM9aWdsGoRbrQkWjKBraj47OpQOBahkAE06GXqmomGMCw4C3Mp0lWds316fZKs91NzygN5m+DYv6alvHVGlChUr/h5ZiN8B6aMJUzTwMMKwkub2AXXB+NEneTI6+vB0ff9wsZ4+8JK/IIUnIO3JMPpNTMiWC/CS/yG9yHX2IeCSjxVoaDTYxL8iWRfgHlrzdbQ==</latexit>

Eq

ο£Ώ logq(z|x)

pβœ“(z)

<latexit sha1_base64="bSdQSp0iZL3ytr3Aon9aL1YVVZc=">AAAC0nicdVHLatwwFNW4r3T6mqTLbkSHQLLoYCeFdhkaAl2mMJMEbGNkjWyLSJYiXZdMVC9Ktv2pfka/oNv2DyrPTGEm014QHJ177oNzcy24hTD80Qvu3X/w8NHW4/6Tp8+evxhs75xZ1RjKJlQJZS5yYpngNZsAB8EutGFE5oKd55fHXf78MzOWq3oMM81SScqaF5wS8FQ2iN8kkkCV5+6kza5wIlgBcSJUiZPCEOqu8N5CULibFn/Bfz/X7X7rdJZAxYCsavZbnBheVpDibDAMR+E88CaIlmCIlnGabfeyZKpoI1kNVBBr4yjUkDpigFPB2n7SWKYJvSQliz2siWQ2dXMXWrzrmSkulPGvBjxnVysckdbOZO6V3bb2bq4j/5WLGyjep47XugFW08WgohEYFO4sxVNuGAUx84BQw/2umFbEuwfe+LVO4yh13XJdm7XxuWz7u3iV6dbQIK/XdUSUyk+o5H9oTjcLtGWNd1VNvYH+JNHdA2yCs4NRdDg6+PR2ePRheZwt9Aq9RnsoQu/QEfqITtEEUfQd/US/0O9gHNwEX4PbhTToLWteorUIvv0BDFnnlA==</latexit>

x⇠pdata(x)

<latexit sha1_base64="Min9X1kSrlMsgf0t/tlM/bkvIlw=">AAACo3icdVFNa9tAEF2rX4nTNk57zGWpMSRQjJQE2mNoL4Vc0mAnAUuI1WpkL9nVLrujYiP0J/prek3+Rf9NVo4hcdwOLDzevJl5O5MZKRyG4d9O8OLlq9dvtra7O2/fvd/t7X24dLqyHMZcS22vM+ZAihLGKFDCtbHAVCbhKrv53uavfoF1QpcjXBhIFJuWohCcoafS3udYMZxlRT1vaOyEoiaNEeZoVZ0zZA09eBQcpr1+OAyXQTdBtAJ9sorzdK+TxrnmlYISuWTOTaLQYFIzi4JLaLpx5cAwfsOmMPGwZApcUi+/1dCBZ3JaaOtfiXTJPq2omXJuoTKvbD2657mW/FduUmHxNalFaSqEkj8MKipJUdN2RzQXFjjKhQeMW+G9Uj5jlnH0m1zrNIqSujXXtlkbn6mmO6BPmdaGQTVf1zE51X7CTP2HFnyzwDio/FZ17hfoTxI9P8AmuDwaRsfDo58n/dNvq+NskX3yiRyQiHwhp+QHOSdjwslv8ofckrtgEJwFF8HoQRp0VjUfyVoEyT0JRtSg</latexit>

latent space

observed space

(a) Computation Graph

20

structured models

hierarchical dynamical

(b) Hierarchical

20

structured models

hierarchical (c) Sequentialdynamical

Figure 2.4: ELBO Computation Graphs. (a)Basic computation graph for varia- tional inference. Outlined circles denote distributions. Smaller red circles denote terms in the ELBO objective. Arrows, again, denote conditional dependencies. This notation can be used to express (b) hierarchical and (c) sequential models with various model dependencies.

log-probability under the prior. The final term is theentropyofπ‘ž(z|x), quantifying the spread or uncertainty of the distribution. Maximizing this term encouragesπ‘ž(z|x) to remain maximally uncertain, all else being equal (Jaynes, 1957). Often, the final two terms are combined, as in Eq. 2.22, with𝐷KL(π‘ž(z|x) ||π‘πœƒ(z)) interpreted as a form of β€œregularization,” penalizing deviations ofπ‘ž(z|x)from π‘πœƒ(z).

This KL divergence term has been the focus of recent investigations. In particular, one can interpret the ELBO as a Lagrangian, with Lagrange multiplier β = 1:

L(x; q, θ) = E_{z∼q(z|x)}[log p_θ(x|z)] − β D_KL(q(z|x) || p_θ(z)).    (2.25)

Thus, the ELBO can be seen as expressing a constrained optimization problem, with the regularizing constraint on D_KL(q(z|x) || p_θ(z)) mediated by β. In models with flexible, e.g., autoregressive, conditional likelihoods, the model may not utilize z, in which case the KL term results in a local maximum at q(z|x) = p_θ(z). This is the so-called "dying units" problem (Bowman et al., 2016), in which the latent variables are effectively unused. A common approach is to anneal β from 0 → 1 during the course of training (Bowman et al., 2016; Sønderby et al., 2016), though alternative schemes exist (Kingma, Salimans, et al., 2016; X. Chen et al., 2017). We can also dynamically adjust β to maintain a particular KL or conditional likelihood constraint (Rezende and Viola, 2018). And by setting β > 1, we can over-regularize q(z|x); with an independent prior, this can yield more disentangled latent variables (Higgins et al., 2017; Burgess et al., 2018). In these latter cases, where β ≠ 1, L(x; q, θ) is no longer a valid lower bound on log p_θ(x). Alemi et al., 2018 provide an information-theoretic framework to make sense of these ideas, identifying the connection between the two terms in the ELBO (Eq. 2.25) and the concepts of distortion and rate, respectively. In short, adjusting β allows one to target particular rate-distortion trade-offs, which, in turn, result in bounds on the mutual information between X and Z. This perspective allows one to consider latent variable models that achieve varying compression rates in complex, natural data domains, such as images (Ballé, Laparra, and Simoncelli, 2017) and video (Lombardo et al., 2019).
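As one concrete instance, here is a minimal sketch of the linear annealing schedule (Bowman et al., 2016) applied to the β-weighted objective of Eq. 2.25; the warmup length is an illustrative assumption:

```python
def beta_elbo(log_px_z, kl, beta):
    """beta-weighted ELBO (Eq. 2.25): E_q[log p_theta(x|z)] - beta * D_KL(q || p)."""
    return log_px_z - beta * kl

def linear_anneal(step, warmup_steps=10000):
    """Anneal beta linearly from 0 to 1 over warmup_steps, then hold it at 1."""
    return min(1.0, step / warmup_steps)

# Inside a training loop, where log_px_z and kl are Monte Carlo estimates
# for the current mini-batch:
#     loss = -beta_elbo(log_px_z, kl, linear_anneal(step))
```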

As with autoregressive models (Figure 2.3), we can represent latent variable models and the ELBO objective as a computation graph, again breaking the dependency structure graphs from Figure 2.2 into separate distributions and terms in the objective.

In Figure 2.4, we illustrate examples of these graphs. Each variable contains a red circle, denoting a term in the ELBO objective. In comparison with the fully-observed autoregressive model, we now have an additional objective term for the latent variable, corresponding to the KL divergence. This graphical representation also allows us to visualize the variational objective for more complex hierarchical and sequential models (Figures 2.4b & 2.4c).

Inference Optimization

In deriving variational inference and variational EM, we described inference optimization, the E-step, as the process of maximizing L(x; q, θ) w.r.t. the distribution q(z|x). The fact that probability distributions are functions makes L(x; q, θ) a functional and makes inference optimization a variational calculus problem. One could solve this optimization problem using a variety of techniques, ranging from gradient-free, e.g., evolution strategies, to first-order gradient-based, e.g., black-box variational inference (Ranganath, Gerrish, and Blei, 2014), to second-order gradient-based, e.g., the Laplace approximation (Friston et al., 2007; Park, C. Kim, and G. Kim, 2019). Despite this range, first-order gradient-based techniques are most common in practice; when combined with amortization (see below), they can result in exceedingly efficient optimization. For this reason, we exclusively focus on first-order techniques in the following chapters, which we now describe in more detail.

With a parametricπ‘ž(z|x), first-order inference optimization entails calculating the gradients ofL (x;π‘ž, πœƒ) w.r.t. the distribution parameters ofπ‘ž, which we refer to as Ξ». For instance, ifπ‘ž(z|x) =N (z;Β΅π‘ž,diag(Οƒπ‘ž2)), i.e.,Ξ» =

Β΅π‘ž,Οƒπ‘ž

, then we must calculateβˆ‡Β΅

π‘žL andβˆ‡Οƒ

π‘žL. We can consider the ELBO as an example of a stochastic

8

(a) Inference

10

z<latexit sha1_base64="JJBRHMAw1+FJL17Y9ikHFHSRDGA=">AAACf3icdVHLTgIxFC3jC/EFunTTSIiuyIyaqDuiG5eY8IowIZ1OBxr6mLQdI074C7f6X/6NHWDBgN6kycm5j3N6bxAzqo3r/hScre2d3b3ifung8Oj4pFw57WiZKEzaWDKpegHShFFB2oYaRnqxIogHjHSDyVOW774RpakULTONic/RSNCIYmQs9TrgyIyDKP2YDctVt+7OA24CbwmqYBnNYaUwHIQSJ5wIgxnSuu+5sfFTpAzFjMxKg0STGOEJGpG+hQJxov10bnkGa5YJYSSVfcLAObvakSKu9ZQHtjKzqNdzGflXrp+Y6N5PqYgTQwReCEUJg0bC7P8wpIpgw6YWIKyo9QrxGCmEjd1SblLL89PMXDYmJx/wWakGV5nMRmz4e74OsZG0CmP+D03xZkOsSWK3KkO7QHsSb/0Am6BzXfdu6tcvt9XG4/I4RXAOLsAV8MAdaIBn0ARtgIEAn+ALfDsF59KpO+6i1Ckse85ALpyHX0lCxig=</latexit>

<latexit sha1_base64="HlSpur9H0Rx2BbfN6EKhfq3sCbs=">AAACgXicdVHLSgMxFE2nPmp9tbp0EywF3ZSZKii4KbpxWaEv6Awlk0nb0GQSkoxYhv6GW/0t/8ZMO4s+9ELgcO65uYdzQ8moNq77U3CKe/sHh6Wj8vHJ6dl5pXrR0yJRmHSxYEINQqQJozHpGmoYGUhFEA8Z6Yezl6zffydKUxF3zFySgKNJTMcUI2Mp3w956jMrj9BiVKm5DXdZcBd4OaiBvNqjamHkRwInnMQGM6T10HOlCVKkDMWMLMp+oolEeIYmZGhhjDjRQbo0vYB1y0RwLJR9sYFLdn0iRVzrOQ+tkiMz1du9jPyrN0zM+DFIaSwTQ2K8WjROGDQCZgnAiCqCDZtbgLCi1ivEU6QQNjanjZ86XpBm5rJvNtaHfFGuw3UmsyEN/9jUITYRdsOU/0NTvDsgNUlsqiKyAdqTeNsH2AW9ZsO7azTf7mut5/w4JXAFrsEN8MADaIFX0AZdgIEEn+ALfDtF59ZxneZK6hTymUuwUc7TL+fPxtw=</latexit>

logq

<latexit sha1_base64="Dq+cAEIkibq9J0Jy0Skhk1KMXas=">AAACjHicdVFdS8MwFM3q15w6p+KTL8Ex8Gm0U1AQYSiIjxP2BWspaZptYUlTk1QcpT/GV/1F/hvTbQ/70AuBw7n35B7ODWJGlbbtn4K1tb2zu1fcLx0cHpWPKyenXSUSiUkHCyZkP0CKMBqRjqaakX4sCeIBI71g8pT3e+9EKiqitp7GxONoFNEhxUgbyq+cu0yM4JufugFPXWaEIcoyv1K16/as4CZwFqAKFtXyTwq+GwqccBJpzJBSA8eOtZciqSlmJCu5iSIxwhM0IgMDI8SJ8tKZ/wzWDBPCoZDmRRrO2GVFirhSUx6YSY70WK33cvKv3iDRwzsvpVGcaBLh+aJhwqAWMA8DhlQSrNnUAIQlNV4hHiOJsDaRrfzUdrw0N5d/s7I+4FmpBpeZ3Eas+cfqHGIjYTaM+T80xZuCWJHEpCpCE6A5ibN+gE3QbdSd63rj9abafFwcpwguwCW4Ag64BU3wAlqgAzBIwSf4At9W2bqx7q2H+ahVWGjOwEpZz79GS8rt</latexit>

⇀

<latexit sha1_base64="Nkk3T+niU2cCEWunvav3KuI5f5w=">AAACdnicdVHLSgMxFE3HV62vVpeCBEtRXJSZKuiy6MZlC31BO5RMetuGJpMhyYhl6Be41Y/zT1yaabvoQy8EDuc+zsm9QcSZNq77nXF2dvf2D7KHuaPjk9OzfOG8pWWsKDSp5FJ1AqKBsxCahhkOnUgBEQGHdjB5SfPtN1CaybBhphH4goxCNmSUGEvV7/r5olt254G3gbcERbSMWr+Q6fcGksYCQkM50brruZHxE6IMoxxmuV6sISJ0QkbQtTAkArSfzJ3OcMkyAzyUyr7Q4Dm72pEQofVUBLZSEDPWm7mU/CvXjc3wyU9YGMUGQroQGsYcG4nTb+MBU0ANn1pAqGLWK6Zjogg1djlrkxqen6Tm0jFr8oGY5Up4lUltREa8r9cRPpJWYSz+oRndbog0xHarcmAXaE/ibR5gG7QqZe++XKk/FKvPy+Nk0SW6RrfIQ4+oil5RDTURRYA+0Cf6yvw4V07JuVmUOpllzwVaC8f9BZZ5wrc=</latexit>

(b) Score Function

9

z<latexit sha1_base64="JJBRHMAw1+FJL17Y9ikHFHSRDGA=">AAACf3icdVHLTgIxFC3jC/EFunTTSIiuyIyaqDuiG5eY8IowIZ1OBxr6mLQdI074C7f6X/6NHWDBgN6kycm5j3N6bxAzqo3r/hScre2d3b3ifung8Oj4pFw57WiZKEzaWDKpegHShFFB2oYaRnqxIogHjHSDyVOW774RpakULTONic/RSNCIYmQs9TrgyIyDKP2YDctVt+7OA24CbwmqYBnNYaUwHIQSJ5wIgxnSuu+5sfFTpAzFjMxKg0STGOEJGpG+hQJxov10bnkGa5YJYSSVfcLAObvakSKu9ZQHtjKzqNdzGflXrp+Y6N5PqYgTQwReCEUJg0bC7P8wpIpgw6YWIKyo9QrxGCmEjd1SblLL89PMXDYmJx/wWakGV5nMRmz4e74OsZG0CmP+D03xZkOsSWK3KkO7QHsSb/0Am6BzXfdu6tcvt9XG4/I4RXAOLsAV8MAdaIBn0ARtgIEAn+ALfDsF59KpO+6i1Ckse85ALpyHX0lCxig=</latexit>

✏

<latexit sha1_base64="mlLz5PRn6k/2I/uXqfF1v+kMzi4=">AAACgnicdVFNS0JBFB1fVmZfWss2QyJEC3kvg1q0kNq0NMgK9CHzxqsOzVcz8yJ5+Dva1s/q3zRPXWTWhQuHc78O5yaaM+vC8KsQrBXXNzZLW+Xtnd29/Ur14MGq1FDoUMWVeUqIBc4kdBxzHJ60ASISDo/J801ef3wFY5mS926iIRZkJNmQUeI8FfcSkfVAW8aVnPYrtbARzgKvgmgBamgR7X610O8NFE0FSEc5sbYbhdrFGTGOUQ7Tci+1oAl9JiPoeiiJABtnM9VTXPfMAA+V8SkdnrE/JzIirJ2IxHcK4sb2dy0n/6p1Uze8jDMmdepA0vmhYcqxUzi3AA+YAer4xANCDfNaMR0TQ6jzRi1tuo/iLBeXr1k6n4hpuY5/MrkM7cTbch/hI+UvjMU/NKOrA9pC6l1VA2+gf0n0+wGr4OGsETUbZ3fntdb14jkldISO0QmK0AVqoVvURh1E0Qt6Rx/oMygGp0EUNOetQWExc4iWIrj6Bk+fx38=</latexit> <latexit sha1_base64="HlSpur9H0Rx2BbfN6EKhfq3sCbs=">AAACgXicdVHLSgMxFE2nPmp9tbp0EywF3ZSZKii4KbpxWaEv6Awlk0nb0GQSkoxYhv6GW/0t/8ZMO4s+9ELgcO65uYdzQ8moNq77U3CKe/sHh6Wj8vHJ6dl5pXrR0yJRmHSxYEINQqQJozHpGmoYGUhFEA8Z6Yezl6zffydKUxF3zFySgKNJTMcUI2Mp3w956jMrj9BiVKm5DXdZcBd4OaiBvNqjamHkRwInnMQGM6T10HOlCVKkDMWMLMp+oolEeIYmZGhhjDjRQbo0vYB1y0RwLJR9sYFLdn0iRVzrOQ+tkiMz1du9jPyrN0zM+DFIaSwTQ2K8WjROGDQCZgnAiCqCDZtbgLCi1ivEU6QQNjanjZ86XpBm5rJvNtaHfFGuw3UmsyEN/9jUITYRdsOU/0NTvDsgNUlsqiKyAdqTeNsH2AW9ZsO7azTf7mut5/w4JXAFrsEN8MADaIFX0AZdgIEEn+ALfDtF59ZxneZK6hTymUuwUc7TL+fPxtw=</latexit>

(c) Pathwise Derivative Figure 2.5: Stochastic Gradient Estimation. (a)Variational inference with para- metric approximate posteriors requires optimizingLw.r.t. the distribution parameters, Ξ». This often requires estimating stochastic gradients (red dotted lines). (b) The score function estimator (Eq. 2.26) is applicable to any distribution, but suffers from high variance. Note: the solid red arrow denotes the per-sample objective,𝑙(x,z;πœƒ). (c)The pathwise derivative estimator (Eq. 2.27), in contrast, has lower variance, but is less widely applicable.

computation graph (Schulman et al., 2015), containing the stochastic sampling operation z ∼ π‘ž(z|x). Unfortunately, stochastic sampling is non-differentiable, preventing us from simply differentiating through the chainΞ»β†’zβ†’ L. Schulman et al., 2015 highlight two stochastic gradient estimators to tackle this issue: the score functionestimator (Glynn, 1990; Williams, 1992; Fu, 2006) and thepathwise derivativeestimator (Kingma and Welling, 2014; Rezende, Mohamed, and Wierstra, 2014; Titsias and LΓ‘zaro-Gredilla, 2014), each of which allows us to calculate stochastic estimates ofβˆ‡Ξ»L.

The score function estimator, sometimes referred to as the REINFORCE estimator (Williams, 1992), is generally applicable to any parametric distribution. With q_λ(z|x) denoting the dependence of q on λ, as well as L(x; q, θ) = E_{z∼q_λ(z|x)}[l(x, z; θ)], the score function estimator allows us to express ∇_λ L as

∇_λ E_{z∼q_λ(z|x)}[l(x, z; θ)] = E_{z∼q_λ(z|x)}[∇_λ log q_λ(z|x) l(x, z; θ)].    (2.26)

While the score function estimator is unbiased and can be successfully utilized for variational inference (Gregor et al., 2014; Ranganath, Gerrish, and Blei, 2014; Mnih and Gregor, 2014; Mnih and Rezende, 2016), in practice, it requires considerable tuning and computation, owing to its empirically high variance.
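To make the estimator concrete, here is a minimal NumPy sketch of Eq. 2.26 for a one-dimensional Gaussian q_λ(z|x) with λ = {μ_q, σ_q}; the per-sample objective l is passed in as a callable, and the sample count is an illustrative assumption:

```python
import numpy as np

def score_function_grad(x, mu, sigma, l, n_samples=1000):
    """Score function (REINFORCE) estimate of Eq. 2.26 for q(z|x) = N(mu, sigma^2).

    Returns Monte Carlo estimates of the gradients w.r.t. lambda = (mu, sigma).
    """
    z = mu + sigma * np.random.randn(n_samples)   # samples z ~ q_lambda(z|x)
    lz = l(x, z)                                  # per-sample objective l(x, z; theta)
    grad_log_q_mu = (z - mu) / sigma**2           # d/dmu log N(z; mu, sigma^2)
    grad_log_q_sigma = (z - mu)**2 / sigma**3 - 1.0 / sigma
    return np.mean(grad_log_q_mu * lz), np.mean(grad_log_q_sigma * lz)
```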

Alternatively, the pathwise derivative estimator, sometimes referred to as the reparameterization estimator (Kingma and Welling, 2014), when applicable, allows us to obtain unbiased gradient estimates with considerably lower variance. This is accomplished by reparameterizing z in terms of an auxiliary random variable, enabling differentiation through stochastic sampling. The most common example is reparameterizing z ∼ N(z; μ_q, diag(σ_q²)) as z = μ_q + σ_q ⊙ ε, where ε ∼ N(ε; 0, I) and ⊙ denotes element-wise multiplication. More generally, for some deterministic function g, if we can express z = g(λ, ε) with ε ∼ p(ε), then the pathwise derivative estimator allows us to express ∇_λ L as

∇_λ E_{z∼q_λ(z|x)}[l(x, z; θ)] = E_{ε∼p(ε)}[∇_λ l(x, g(λ, ε); θ)].    (2.27)

With these stochastic gradient estimators in hand, we can perform stochastic gradient-based optimization of L(x; q, θ) w.r.t. λ (Ranganath, Gerrish, and Blei, 2014): estimating ∇_λ L by sampling z (or ε) and updating λ using stochastic gradient ascent. Unfortunately, a naïve implementation of this inference optimization (E-step) procedure scales poorly to large models and data sets. This is because many gradient steps must be performed per example, just to provide one estimate of L(x; q, θ) for M-step optimization. More sophisticated approaches cache previous estimates of λ for each example to initialize future inference optimization. However, beyond the additional memory overhead, this approach still requires tuning the inference optimizer's learning rate and is not applicable to online learning settings.
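The corresponding pathwise estimate for the same one-dimensional Gaussian case is sketched below. It assumes, for simplicity, that l depends on λ only through z (as when the KL or entropy term is computed analytically), and grad_l, a callable returning ∂l/∂z, is a hypothetical stand-in for whatever differentiation the surrounding framework provides:

```python
import numpy as np

def pathwise_grad(x, mu, sigma, grad_l, n_samples=1000):
    """Pathwise (reparameterization) estimate of Eq. 2.27 for q(z|x) = N(mu, sigma^2).

    Reparameterize z = g(lambda, eps) = mu + sigma * eps with eps ~ N(0, 1),
    so dz/dmu = 1 and dz/dsigma = eps; the chain rule then gives the gradients.
    """
    eps = np.random.randn(n_samples)
    z = mu + sigma * eps
    dl_dz = grad_l(x, z)                  # per-sample derivative of l w.r.t. z
    grad_mu = np.mean(dl_dz)              # dz/dmu = 1
    grad_sigma = np.mean(dl_dz * eps)     # dz/dsigma = eps
    return grad_mu, grad_sigma
```

For a fixed sample budget, this estimate typically exhibits far lower variance than the score function estimate above, at the cost of requiring a differentiable l and a reparameterizable q.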

In order to apply variational inference more broadly, we require a simple technique for dramatically improving the efficiency of inference optimization. Amortization (Gershman and Goodman, 2014), a form of meta-optimization, offers a viable solution.

Amortized Variational Inference

Amortization, in a generic sense, refers to spreading out costs. In amortized inference, these "costs" are the computational costs of performing inference optimization. Thus, rather than separately optimizing λ for each data example, we amortize this optimization cost using a learned optimizer, i.e., an inference model. By using this meta-optimization procedure, we can perform inference optimization far more efficiently for each example, with a negligible cost for learning the inference model.

The concept of inference models is deeply intertwined with deep latent variable models, popularized by the Helmholtz Machine (Dayan et al., 1995), which was formulated as an autoencoder (Ballard, 1987). Formally, in such setups, the inference model is a direct mapping from x to λ:

λ← π‘“πœ™(x), (2.28)

35 Β΅ (x)

<latexit sha1_base64="f4POFtkrBQtlxidt2haan9e3VOs=">AAACkXicdVHLSgMxFE3HV62vane6CRZBN2WmCuqu6EZwo9AXdIYhk2baYDIJSUYsw4Bf41Z/x78xU2dhrV4IHM59nZwbSUa1cd3PirOyura+Ud2sbW3v7O7V9w/6WqQKkx4WTKhhhDRhNCE9Qw0jQ6kI4hEjg+jptsgPnonSVCRdM5Mk4GiS0JhiZCwV1g/9iGc+T/PQl1MKT32OzDSKs5f8LKw33ZY7D7gMvBI0QRkP4X4l9McCp5wkBjOk9chzpQkypAzFjOQ1P9VEIvyEJmRkYYI40UE2/0QOTywzhrFQ9iUGztmfHRniWs94ZCsLjfp3riD/yo1SE18FGU1kakiCvxfFKYNGwMIROKaKYMNmFiCsqNUK8RQphI31bWFS1wuyQlwxZmF9xPPaCfzJFDKk4S+LdYhNhN0w5f/QFC83SE1S66oYWwPtSbzfB1gG/XbLO2+1Hy+anZvyOFVwBI7BKfDAJeiAO/AAegCDV/AG3sGH03CunY5T1jqVsqcBFsK5/wKbEczF</latexit>

(x)

<latexit sha1_base64="byTPipI6nSdWcCa+1Ot0FEheE08=">AAAClHicdVHLSgMxFE3HV62vquBGhGARdFNmVNCFi2IRXEmFVoXOMGTSTBtMJiHJiGWYlV/jVr/GvzHTdtGHXggczn2dnBtJRrVx3Z+Ss7S8srpWXq9sbG5t71R39560SBUmHSyYUC8R0oTRhHQMNYy8SEUQjxh5jl6bRf75jShNRdI2Q0kCjvoJjSlGxlJh9ciPeOZr2ucoD305oPDU58gMojh7z8/Cas2tu6OAi8CbgBqYRCvcLYV+T+CUk8RghrTueq40QYaUoZiRvOKnmkiEX1GfdC1MECc6yEb/yOGJZXowFsq+xMARO92RIa71kEe2stCo53MF+Veum5r4OshoIlNDEjxeFKcMGgELU2CPKoING1qAsKJWK8QDpBA21rqZSW0vyApxxZiZ9RHPKydwmilkSMPfZ+sQ6wu7YcD/oSlebJCapNZV0bMG2pN48wdYBE/nde+ifv54WWvcTo5TBofgGJwCD1yBBrgHLdABGHyAT/AFvp0D58ZpOnfjUqc06dkHM+E8/AKgM84S</latexit>

z<latexit sha1_base64="JJBRHMAw1+FJL17Y9ikHFHSRDGA=">AAACf3icdVHLTgIxFC3jC/EFunTTSIiuyIyaqDuiG5eY8IowIZ1OBxr6mLQdI074C7f6X/6NHWDBgN6kycm5j3N6bxAzqo3r/hScre2d3b3ifung8Oj4pFw57WiZKEzaWDKpegHShFFB2oYaRnqxIogHjHSDyVOW774RpakULTONic/RSNCIYmQs9TrgyIyDKP2YDctVt+7OA24CbwmqYBnNYaUwHIQSJ5wIgxnSuu+5sfFTpAzFjMxKg0STGOEJGpG+hQJxov10bnkGa5YJYSSVfcLAObvakSKu9ZQHtjKzqNdzGflXrp+Y6N5PqYgTQwReCEUJg0bC7P8wpIpgw6YWIKyo9QrxGCmEjd1SblLL89PMXDYmJx/wWakGV5nMRmz4e74OsZG0CmP+D03xZkOsSWK3KkO7QHsSb/0Am6BzXfdu6tcvt9XG4/I4RXAOLsAV8MAdaIBn0ARtgIEAn+ALfDsF59KpO+6i1Ckse85ALpyHX0lCxig=</latexit>

✏<latexit sha1_base64="mlLz5PRn6k/2I/uXqfF1v+kMzi4=">AAACgnicdVFNS0JBFB1fVmZfWss2QyJEC3kvg1q0kNq0NMgK9CHzxqsOzVcz8yJ5+Dva1s/q3zRPXWTWhQuHc78O5yaaM+vC8KsQrBXXNzZLW+Xtnd29/Ur14MGq1FDoUMWVeUqIBc4kdBxzHJ60ASISDo/J801ef3wFY5mS926iIRZkJNmQUeI8FfcSkfVAW8aVnPYrtbARzgKvgmgBamgR7X610O8NFE0FSEc5sbYbhdrFGTGOUQ7Tci+1oAl9JiPoeiiJABtnM9VTXPfMAA+V8SkdnrE/JzIirJ2IxHcK4sb2dy0n/6p1Uze8jDMmdepA0vmhYcqxUzi3AA+YAer4xANCDfNaMR0TQ6jzRi1tuo/iLBeXr1k6n4hpuY5/MrkM7cTbch/hI+UvjMU/NKOrA9pC6l1VA2+gf0n0+wGr4OGsETUbZ3fntdb14jkldISO0QmK0AVqoVvURh1E0Qt6Rx/oMygGp0EUNOetQWExc4iWIrj6Bk+fx38=</latexit>

⇀<latexit sha1_base64="Nkk3T+niU2cCEWunvav3KuI5f5w=">AAACdnicdVHLSgMxFE3HV62vVpeCBEtRXJSZKuiy6MZlC31BO5RMetuGJpMhyYhl6Be41Y/zT1yaabvoQy8EDuc+zsm9QcSZNq77nXF2dvf2D7KHuaPjk9OzfOG8pWWsKDSp5FJ1AqKBsxCahhkOnUgBEQGHdjB5SfPtN1CaybBhphH4goxCNmSUGEvV7/r5olt254G3gbcERbSMWr+Q6fcGksYCQkM50brruZHxE6IMoxxmuV6sISJ0QkbQtTAkArSfzJ3OcMkyAzyUyr7Q4Dm72pEQofVUBLZSEDPWm7mU/CvXjc3wyU9YGMUGQroQGsYcG4nTb+MBU0ANn1pAqGLWK6Zjogg1djlrkxqen6Tm0jFr8oGY5Up4lUltREa8r9cRPpJWYSz+oRndbog0xHarcmAXaE/ibR5gG7QqZe++XKk/FKvPy+Nk0SW6RrfIQ4+oil5RDTURRYA+0Cf6yvw4V07JuVmUOpllzwVaC8f9BZZ5wrc=</latexit>

+

<latexit sha1_base64="Z+9aU0F2yHtHC/FjEFUaW/dEtUg=">AAACdnicdVHLSgMxFE3HV62vVpeCBEtREMpMFXRZdOOyhb6gHUomvW1Dk8mQZMQy9Avc6sf5Jy7NtF30oRcCh3Mf5+TeIOJMG9f9zjg7u3v7B9nD3NHxyelZvnDe0jJWFJpUcqk6AdHAWQhNwwyHTqSAiIBDO5i8pPn2GyjNZNgw0wh8QUYhGzJKjKXqd/180S2788DbwFuCIlpGrV/I9HsDSWMBoaGcaN313Mj4CVGGUQ6zXC/WEBE6ISPoWhgSAdpP5k5nuGSZAR5KZV9o8Jxd7UiI0HoqAlspiBnrzVxK/pXrxmb45CcsjGIDIV0IDWOOjcTpt/GAKaCGTy0gVDHrFdMxUYQau5y1SQ3PT1Jz6Zg1+UDMciW8yqQ2IiPe1+sIH0mrMBb/0IxuN0QaYrtVObALtCfxNg+wDVqVsndfrtQfitXn5XGy6BJdo1vkoUdURa+ohpqIIkAf6BN9ZX6cK6fk3CxKncyy5wKtheP+ApiLwrg=</latexit>

Figure 2.6: Variational Autoencoder (VAE). VAEs combine direct amortization (Eq. 2.28) and the pathwise derivative estimator (Eq. 2.27, top) with Gaussian approximate posteriors to train deep latent variable models. In the model diagram (center), the amortized inference model (left) acts as anencoder, with the conditional likelihood (right) acting as adecoder. Each are parameterized by deep networks.

where π‘“πœ™is a model (deep network) with parametersπœ™. Conventionally, we denote the approximate posterior as π‘žπœ™(z|x) to denote the parameterization by πœ™. Now, rather than optimizingΞ»using gradient-based techniques, we periodically update πœ™ usingβˆ‡πœ™L = πœ•Lπœ•Ξ»πœ•πœ• πœ™Ξ», thereby letting π‘“πœ™ learn to optimize Ξ». This procedure is incredibly simple, as we only need to tune the learning rate forπœ™, and efficient, as we have an estimate ofΞ»after only one forward pass through π‘“πœ™. Amortization is also widely applicable: if we can estimateβˆ‡Ξ»L using stochastic gradient estimation (see above), we can continue differentiating through the chainπœ™β†’ Ξ»β†’zβ†’ L.

When direct amortization is combined with the pathwise derivative estimator in deep latent variable models, the resulting setup is referred to as a variational autoencoder (VAE) (Kingma and Welling, 2014; Rezende, Mohamed, and Wierstra, 2014). In this autoencoder interpretation, q_φ(z|x) is an encoder, z is the latent code, and p_θ(x|z) is a decoder. A computation graph is shown in Figure 2.6. This direct encoding scheme seems intuitively obvious: in the same way that p_θ(x|z) directly maps z to a distribution over x, q_φ(z|x) directly maps x to a distribution over z. Indeed, with perfect knowledge of p_θ(x, z), f_φ could act as a lookup table, precisely mapping each x to the corresponding optimal λ. However, as we discuss in Chapter 3, there are a number of subtle issues that can arise with direct amortization, forming one of