Chapter IV: Amortized Variational Filtering
4.3 Variational Filtering
on an empirical basis for deterministic models (Lotter, Kreiman, and Cox, 2017;
Henaff, Zhao, and LeCun, 2017). PredNet (Lotter, Kreiman, and Cox, 2017) is a deterministic model that encodes prediction errors to perform inference. Like our method, this approach is inspired by hierarchical predictive coding (Rao and Ballard, 1999; Friston, 2005). Predictive coding, in turn, is an approximation to classical Bayesian filtering (SΓ€rkkΓ€, 2013), which updates the posterior from the prior using the log-likelihood (error) of the prediction. For linear Gaussian models, this manifests as the Kalman filter (Kalman, 1960), which uses prediction errors to perform exact inference.
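To make the connection concrete, the following is a minimal sketch of a scalar Kalman filter, illustrating how the posterior is updated from the prior using a gain-weighted prediction error. The random-walk dynamics and the noise variances are illustrative assumptions, not taken from the text.

```python
import numpy as np

# Hypothetical scalar model: z_t = z_{t-1} + w_t,  x_t = z_t + v_t.
# The posterior is an exact update of the prior via the prediction error.
def kalman_step(mu, var, x, q_var=0.1, r_var=0.5):
    # predict: form the prior for step t from the previous posterior
    mu_prior, var_prior = mu, var + q_var
    # prediction error (innovation) and Kalman gain
    error = x - mu_prior
    gain = var_prior / (var_prior + r_var)
    # update: move the prior toward the observation by the weighted error
    mu_post = mu_prior + gain * error
    var_post = (1.0 - gain) * var_prior
    return mu_post, var_post

mu, var = 0.0, 1.0
for x in [0.9, 1.1, 1.0]:
    mu, var = kalman_step(mu, var, x)
```

For linear Gaussian models this update is exact; the chapter's variational filtering procedure generalizes the same predict-then-correct pattern to non-linear, non-Gaussian models.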
Finally, several recent works have used particle filtering in conjunction with amortized inference to provide a tighter bound on the log-likelihood for sequential models (Maddison et al., 2017; Naesseth et al., 2018; Le et al., 2018). The techniques developed here can also be applied to this setting, providing an improved method for estimating the proposal distribution.
Figure 4.1: Variational Filtering Inference. The diagram shows filtering inference within a sequential latent variable model, as outlined in Algorithm 2. The central gray region depicts inference optimization at step $t$, initialized at or near the corresponding prior (white), $p_\theta(z_t | x_{<t}, z_{<t})$. Sampling from the approximate posterior (blue) generates the conditional likelihood (green), $p_\theta(x_t | x_{<t}, z_{\le t})$, which is evaluated at the observation (gray), $x_t$, to calculate the prediction error. This term is combined with the KL divergence between the approximate posterior and prior, yielding the step ELBO (red), $\mathcal{L}_t$ (Eq. 4.7). Inference optimization (E-step) involves finding the approximate posterior that maximizes the step ELBO terms.
re-expressed as a reconstruction term and a KL divergence term:
$$\mathcal{L}_t = \mathbb{E}_{q(z_t | x_{\le t}, z_{<t})} \left[ \log p_\theta(x_t | x_{<t}, z_{\le t}) \right] - D_{\mathrm{KL}} \left( q(z_t | x_{\le t}, z_{<t}) \, \| \, p_\theta(z_t | x_{<t}, z_{<t}) \right). \tag{4.7}$$
The filtering ELBO in Eq. 4.5 is the sum of these step ELBO terms, each of which is evaluated according to expectations over past latent sequences. To perform filtering variational inference, we must find the set of $T$ terms in $q(z_{1:T} | x_{1:T})$ that maximize the filtering ELBO summation.
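As a concrete illustration of Eq. 4.7, the sketch below estimates a step ELBO for an assumed diagonal-Gaussian prior and approximate posterior with a Gaussian conditional likelihood: a Monte Carlo reconstruction term minus a closed-form KL term. The identity decoder, dimensions, and unit observation noise are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    # D_KL(q || p) between diagonal Gaussians, summed over latent dimensions
    return 0.5 * np.sum(np.log(var_p / var_q)
                        + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

def step_elbo(x, mu_q, var_q, mu_p, var_p, decoder, n_samples=64):
    # Monte Carlo estimate of E_q[log p_theta(x_t | x_<t, z_<=t)]
    z = mu_q + np.sqrt(var_q) * rng.standard_normal((n_samples, mu_q.size))
    x_mu = decoder(z)  # predicted mean of x_t (unit-variance Gaussian assumed)
    log_lik = -0.5 * np.mean(np.sum((x - x_mu) ** 2 + np.log(2 * np.pi), axis=-1))
    # step ELBO: reconstruction minus KL (Eq. 4.7)
    return log_lik - gaussian_kl(mu_q, var_q, mu_p, var_p)

decoder = lambda z: z  # identity decoder (assumption)
mu_p, var_p = np.zeros(2), np.ones(2)
# evaluating with q initialized at the prior: the KL term is exactly zero
elbo_at_prior = step_elbo(np.array([0.5, -0.5]), mu_p, var_p, mu_p, var_p, decoder)
```

When $q$ sits at the prior, only the reconstruction (prediction error) term contributes, which is exactly the initialization used by the filtering procedure below.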
We now describe the variational filtering EM algorithm, given in Algorithm 2 and depicted in Figure 4.1, which optimizes Eq. 4.5. This algorithm sequentially optimizes each of the approximate posterior terms to perform filtering inference.
Consider the approximate posterior at step $t$, $q(z_t | x_{\le t}, z_{<t})$. This term appears in $\mathcal{L}$, either directly or in expectations, in terms $t$ through $T$ of the summation:
$$\mathcal{L} = \underbrace{\widetilde{\mathcal{L}}_1 + \widetilde{\mathcal{L}}_2 + \cdots + \widetilde{\mathcal{L}}_{t-1}}_{\text{steps on which } q(z_t | x_{\le t}, z_{<t}) \text{ depends}} + \overbrace{\widetilde{\mathcal{L}}_t + \widetilde{\mathcal{L}}_{t+1} + \cdots + \widetilde{\mathcal{L}}_{T-1} + \widetilde{\mathcal{L}}_T}^{\text{terms in which } q(z_t | x_{\le t}, z_{<t}) \text{ appears}}. \tag{4.8}$$
However, the filtering setting dictates that the optimization of the approximate posterior at each step can only condition on past and present variables, i.e., steps 1 through $t$. Therefore, of the $T$ terms in $\mathcal{L}$, the only term through which we can optimize $q(z_t | x_{\le t}, z_{<t})$ is the $t^{\text{th}}$ term:
$$q^*(z_t | x_{\le t}, z_{<t}) = \underset{q(z_t | x_{\le t}, z_{<t})}{\arg\max} \; \widetilde{\mathcal{L}}_t. \tag{4.9}$$
Optimizing $\widetilde{\mathcal{L}}_t$ requires evaluating expectations over previous approximate posteriors. Again, because approximate posterior estimates cannot be influenced by future variables, these past expectations remain fixed through the future. Thus, variational filtering (the variational E-step) can be performed by sequentially maximizing each $\widetilde{\mathcal{L}}_t$ w.r.t. $q(z_t | x_{\le t}, z_{<t})$, holding the expectations over past variables fixed. Conveniently, once the past expectations have been evaluated, inference optimization is entirely defined by the ELBO at that step.
For simple models, such as linear Gaussian models, these expectations may be computed exactly. However, in general, the expectations must be estimated through Monte Carlo samples from $q$, with inference optimization carried out using stochastic gradients (Ranganath, Gerrish, and Blei, 2014). As in the static setting, we can initialize $q(z_t | x_{\le t}, z_{<t})$ at (or near) the prior, $p_\theta(z_t | x_{<t}, z_{<t})$. This yields a simple interpretation: starting with $q$ at the prior, we generate a prediction of the data through the likelihood, $p_\theta(x_t | x_{<t}, z_{\le t})$, to evaluate the current step ELBO. Using the approximate posterior gradient, we then perform an inference update to the estimate of $q$. This resembles classical Bayesian filtering, where the posterior is updated from the prior prediction according to the likelihood of observations. Unlike the classical setting, reconstruction and update steps are repeated until inference convergence. This resembles the extended Kalman filter (Kalman, 1960; Murphy, 2012), but is applicable to non-linear models with non-Gaussian distributions.
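The repeated predict-and-update loop at a single step can be sketched as follows. To keep the gradient in closed form, a linear Gaussian model with an identity decoder and unit variances is assumed, so the step-ELBO gradient w.r.t. the posterior mean reduces to a prediction error minus a prior (KL) error; these modeling choices are illustrative, not the chapter's general setting.

```python
import numpy as np

# E-step at one filtering step: initialize q at the prior, then iterate
# gradient ascent on the step ELBO until (approximate) convergence.
def infer_step(x, mu_prior, lr=0.1, n_iters=100):
    mu = mu_prior.copy()                 # initialize q at the prior
    for _ in range(n_iters):
        pred_error = x - mu              # reconstruction (likelihood) error
        prior_error = mu - mu_prior      # KL pull back toward the prior
        mu += lr * (pred_error - prior_error)  # ascend the step ELBO
    return mu

x, mu_prior = np.array([1.0, -2.0]), np.zeros(2)
mu_post = infer_step(x, mu_prior)  # converges toward (x + mu_prior) / 2 here
```

With equal likelihood and prior precisions, the optimum lands midway between observation and prior, mirroring the gain-weighted correction of a Kalman update.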
After inferring an optimal approximate posterior, learning (the variational M-step) can be performed by maximizing the total filtering ELBO w.r.t. the model parameters, $\theta$. As Eq. 4.5 is a summation and differentiation is a linear operation, $\nabla_\theta \mathcal{L}$ is the
Algorithm 2
1: Input: observation sequence $x_{1:T}$, model $p_\theta(x_{1:T}, z_{1:T})$
2: $\nabla_\theta \mathcal{L} = 0$ β² parameter gradient
3: for $t = 1$ to $T$ do
4:   initialize $q(z_t | x_{\le t}, z_{<t})$ β² at/near $p_\theta(z_t | x_{<t}, z_{<t})$
5:   $\widetilde{\mathcal{L}}_t := \mathbb{E}_{q(z_{<t} | x_{<t}, z_{<t-1})} [\mathcal{L}_t]$
6:   $q(z_t | x_{\le t}, z_{<t}) = \arg\max_q \widetilde{\mathcal{L}}_t$ β² inference (E-step)
7:   $\nabla_\theta \mathcal{L} \leftarrow \nabla_\theta \mathcal{L} + \nabla_\theta \widetilde{\mathcal{L}}_t$
8: end for
9: $\theta = \theta + \alpha \nabla_\theta \mathcal{L}$ β² learning (M-step)
sum of contributions from each of these terms:
$$\nabla_\theta \mathcal{L} = \sum_{t=1}^{T} \nabla_\theta \left[ \mathbb{E}_{\prod_{\tau=1}^{t-1} q(z_\tau | x_{\le \tau}, z_{<\tau})} \left[ \mathcal{L}_t \right] \right]. \tag{4.10}$$
Parameter gradients can be estimated online by accumulating the result from each term in the filtering ELBO. The parameters are then updated at the end of the sequence. For large datasets, stochastic estimates of parameter gradients can be obtained from a mini-batch of data examples (Hoffman et al., 2013).
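Algorithm 2 as a whole can be sketched on a toy model. The scalar model below (random-walk prior $z_t \sim \mathcal{N}(z_{t-1}, 1)$, likelihood $x_t \sim \mathcal{N}(\theta z_t, 1)$), the fixed unit posterior variance, and the single-sample gradient estimate are all simplifying assumptions made so the E-step and M-step gradients stay in closed form.

```python
import numpy as np

# Variational filtering EM on an assumed toy model:
#   prior       z_t ~ N(z_{t-1}, 1)
#   likelihood  x_t ~ N(theta * z_t, 1)
# Each step runs inference from the prior (E-step) and accumulates the
# parameter gradient online; theta is updated once at sequence end (M-step).
def filtering_em(x_seq, theta, lr_inf=0.1, lr_learn=0.01, n_inf=50):
    grad_theta, mu_prev = 0.0, 0.0
    for x in x_seq:
        mu_prior = mu_prev                 # latent dynamics (random walk)
        mu = mu_prior                      # initialize q at the prior
        for _ in range(n_inf):             # E-step: ascend the step ELBO
            grad_mu = theta * (x - theta * mu) - (mu - mu_prior)
            mu += lr_inf * grad_mu
        # accumulate d(step ELBO)/d(theta), using the posterior mean as a
        # single sample and ignoring variance terms (simplification)
        grad_theta += (x - theta * mu) * mu
        mu_prev = mu                       # past expectation now held fixed
    return theta + lr_learn * grad_theta   # M-step at the end of the sequence

theta = filtering_em([0.5, 1.0, 1.5, 2.0], theta=1.0)
```

Holding `mu_prev` fixed after each step is exactly the "past expectations remain fixed" property that makes the online gradient accumulation valid.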
Amortized Variational Filtering
Performing approximate inference optimization (Algorithm 2, Line 6) with traditional techniques can be computationally costly, requiring many iterations of gradient updates and hand-tuning of optimizer hyper-parameters. In online settings, with large models and data sets, this may be impractical. An alternative approach is to employ an amortized inference model, which can learn to maximize $\mathcal{L}_t$ w.r.t. $q(z_t | x_{\le t}, z_{<t})$ more efficiently at each step. Note that $\mathcal{L}_t$ (Eq. 4.6) contains $p_\theta(x_t, z_t | x_{<t}, z_{<t}) = p_\theta(x_t | x_{<t}, z_{\le t}) \, p_\theta(z_t | x_{<t}, z_{<t})$. The prior, $p_\theta(z_t | x_{<t}, z_{<t})$, varies across steps, constituting the latent dynamics. Direct inference models, which only encode $x_t$, do not have access to the prior and therefore cannot properly optimize $q(z_t | x_{\le t}, z_{<t})$. Many inference models in the sequential setting attempt to account for this information by including hidden states (Chung, Kastner, et al., 2015; Fraccaro, S. K. SΓΈnderby, et al., 2016; Denton and Fergus, 2018). However, given the complexities of many generative models, it can be difficult to determine how to properly route the necessary prior information into the inference model. As a result, each SLVM has been proposed with an accompanying custom inference model.
We propose a simple and general alternative method for amortizing filtering inference that is agnostic to the particular form of the generative model. Iterative amortized inference (Chapter 3) naturally accounts for the changing prior through the approximate posterior gradients or errors. These models are thus a natural candidate for performing inference at each step. Similar to Eq. 3.1, when $q(z_t | x_{\le t}, z_{<t})$ is a parametric distribution with parameters $\lambda_t^q$, the inference update takes the form:
$$\lambda_t^q \leftarrow f_\phi(\lambda_t^q, \nabla_{\lambda_t^q} \widetilde{\mathcal{L}}_t). \tag{4.11}$$
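Structurally, the update in Eq. 4.11 might look like the sketch below: a learned model $f_\phi$ maps the current posterior parameters and the step-ELBO gradient to refined parameters. The single linear layer with a `tanh` step, the random (untrained) weights, and the fixed gradient vector are all illustrative placeholders; in practice $f_\phi$ is a learned network trained jointly with the generative model.

```python
import numpy as np

rng = np.random.default_rng(0)

def f_phi(lam, grad_lam, W):
    # one iterative amortized inference update: lam <- f_phi(lam, grad)
    features = np.concatenate([lam, grad_lam])  # encode params and gradient
    return lam + np.tanh(W @ features)          # learned, gradient-conditioned step

lam = np.zeros(4)                            # e.g., [mu (2 dims), log-var (2 dims)]
grad_lam = np.array([0.5, -0.3, 0.1, 0.0])   # step-ELBO gradient (placeholder values)
W = 0.1 * rng.standard_normal((4, 8))        # phi: inference model weights (untrained)
for _ in range(3):                           # a few refinement iterations per step
    lam = f_phi(lam, grad_lam, W)
```

Because the gradient already carries the prior's influence, the same small update model can be reused at every step without any model-specific wiring.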
We refer to this setup as amortized variational filtering (AVF). Again, Eq. 4.11 offers just one particular encoding form for an iterative inference model. For instance, $x_t$ could be additionally encoded at each step. As noted in Chapter 3, in latent Gaussian models, precision-weighted errors provide an alternative inference optimization signal. There are two main benefits to using iterative inference models in the filtering setting:
β’ The approximate posterior is updated from the prior, so model capacity is utilized for inference corrections rather than re-estimating the latent dynamics at each step.
β’ These inference models contain all of the terms necessary to perform inference optimization, providing a simple model form that does not require any additional hidden states or inputs.
In practice, these advantages permit the use of relatively simple iterative inference models that can perform filtering inference efficiently and accurately. We demonstrate this in the following section.