Mean-Field Variational Inference via Wasserstein Gradient Flow
Variational inference (VI) provides an appealing alternative to traditional sampling-based approaches for implementing Bayesian inference due to its conceptual simplicity, statistical accuracy, and computational scalability. However, common variational approximation schemes, such as the mean-field (MF) approximation, require a certain conjugacy structure to facilitate efficient computation, which may add unnecessary restrictions to the viable prior distribution family and impose further constraints on the variational approximation family. In this work, we develop a general computational framework for implementing MF-VI via Wasserstein gradient flow (WGF), a gradient flow over the space of probability measures. When specialized to Bayesian latent variable models, we analyze the algorithmic convergence of an alternating minimization scheme based on a time-discretized WGF for implementing the MF approximation. In particular, the proposed algorithm resembles a distributional version of the EM algorithm, consisting of an E-step that updates the variational distribution of the latent variables and an M-step that performs steepest descent over the variational distribution of the parameters. Our theoretical analysis relies on optimal transport theory and subdifferential calculus in the space of probability measures. We prove the exponential convergence of the time-discretized WGF for minimizing a generic objective functional under strict convexity along generalized geodesics. We also provide a new proof of the exponential contraction of the variational distribution obtained from the MF approximation by using the fixed-point equation of the time-discretized WGF. We apply our method and theory to two classic Bayesian latent variable models, the Gaussian mixture model and the mixture of regression model. Numerical experiments are also conducted to complement the theoretical findings under these two models.
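To make the alternating scheme concrete, the display below sketches one plausible form of such a distributional EM iteration for a Bayesian latent variable model with data x, latent variables z, and parameters θ. It combines a standard mean-field E-step with a single JKO (time-discretized WGF) step for the M-step; the notation q_z, q_θ, the step size τ, and the functional F are illustrative assumptions, not taken verbatim from the paper.

```latex
% Hedged sketch of a distributional EM iteration via a time-discretized WGF (JKO) step.
% The symbols q_z, q_theta, tau, and F below are assumed, illustrative notation.
\begin{align*}
  \text{E-step:}\quad
    & q_z^{(t+1)}(z) \;\propto\; \exp\!\Big\{ \mathbb{E}_{q_\theta^{(t)}}\big[\log p(x, z \mid \theta)\big] \Big\}, \\[4pt]
  \text{M-step (one JKO step of the WGF):}\quad
    & q_\theta^{(t+1)} \;=\; \operatorname*{arg\,min}_{q \in \mathcal{P}_2(\Theta)}
      \Big\{ F\big(q;\, q_z^{(t+1)}\big) \;+\; \tfrac{1}{2\tau}\, W_2^2\big(q,\, q_\theta^{(t)}\big) \Big\}, \\[4pt]
  \text{where}\quad
    & F\big(q;\, q_z\big) \;=\; -\,\mathbb{E}_{q(\theta)}\,\mathbb{E}_{q_z(z)}\big[\log p(x, z, \theta)\big]
      \;+\; \int q(\theta)\,\log q(\theta)\, d\theta .
\end{align*}
```

Here F is the portion of the negative evidence lower bound that depends on q_θ for a fixed q_z, and the quadratic Wasserstein penalty W_2^2 implements the steepest-descent (gradient-flow) step over the space of probability measures.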