01 Discrete Time

$$ \newcommand{aster}{*} \newcommand{exist}{\exists} \newcommand{B}{\mathbb B} \newcommand{C}{\mathbb C} \newcommand{I}{\mathbb I} \newcommand{N}{\mathbb N} \newcommand{Q}{\mathbb Q} \newcommand{R}{\mathbb R} \newcommand{Z}{\mathbb Z} \newcommand{eR}{\overline {\mathbb R}} \newcommand{cD}{ {\mathbb D}} \newcommand{dD}{ {\part \mathbb D}} \newcommand{dH}{ {\part \mathbb H}} \newcommand{eC}{\overline {\mathbb C}} \newcommand{A}{\mathcal A} \newcommand{D}{\mathcal D} \newcommand{E}{\mathcal E} \newcommand{F}{\mathcal F} \newcommand{G}{\mathcal G} \newcommand{H}{\mathcal H} \newcommand{J}{\mathcal J} \newcommand{L}{\mathcal L} \newcommand{U}{\mathcal U} \newcommand{M}{\mathcal M} \newcommand{O}{\mathcal O} \newcommand{P}{\mathcal P} \newcommand{S}{\mathcal S} \newcommand{T}{\mathcal T} \newcommand{V}{\mathcal V} \newcommand{W}{\mathcal W} \newcommand{X}{\mathcal X} \newcommand{Y}{\mathcal Y} \newcommand{bE}{\symbf E} \newcommand{bF}{\symbf F} \newcommand{bD}{\symbf D} \newcommand{bI}{\symbf I} \newcommand{bX}{\symbf X} \newcommand{bY}{\symbf Y} \newcommand{nz}{\mathcal Z} \newcommand{bT}{\mathbb T} \newcommand{bB}{\mathbb B} \newcommand{bS}{\mathbb S} \newcommand{bA}{\mathbb A} \newcommand{bL}{\mathbb L} \newcommand{bP}{\symbf P} \newcommand{bM}{\symbf M} \newcommand{bH}{\mathbb H} \newcommand{dd}{\mathrm d} \newcommand{Mu}{\mathup M} \newcommand{Tau}{\mathup T} \newcommand{ae}{\operatorname{a.e.}} \newcommand{aut}{\operatorname{aut}} \newcommand{adj}{\operatorname{adj}} \newcommand{char}{\operatorname{char}} \newcommand{cov}{\operatorname{Cov}} \newcommand{cl}{\operatorname{cl}} \newcommand{cont}{\operatorname{cont}} \newcommand{e}{\mathbb E} \newcommand{pp}{\operatorname{primitive}} \newcommand{dist}{\operatorname{dist}} \newcommand{diam}{\operatorname{diam}} \newcommand{fp}{\operatorname{Fp}} \newcommand{from}{\leftarrow} \newcommand{Gal}{\operatorname{Gal}} \newcommand{GCD}{\operatorname{GCD}} \newcommand{LCM}{\operatorname{LCM}} \newcommand{fg}{\mathrm{fg}} 
\newcommand{gf}{\mathrm{gf}} \newcommand{im}{\operatorname{Im}} \newcommand{image}{\operatorname{image}} \newcommand{inj}{\hookrightarrow} \newcommand{irr}{\operatorname{irr}} \newcommand{lcm}{\operatorname{lcm}} \newcommand{ltrieq}{\mathrel{\unlhd}} \newcommand{ltri}{\mathrel{\lhd}} \newcommand{loc}{ {\operatorname{loc}}} \newcommand{null}{\operatorname{null}} \newcommand{part}{\partial} \newcommand{pf}{\operatorname{Pf}} \newcommand{pv}{\operatorname{Pv}} \newcommand{rank}{\operatorname{rank}} \newcommand{range}{\operatorname{range}} \newcommand{re}{\operatorname{Re}} \newcommand{span}{\operatorname{span}} \newcommand{su}{\operatorname{supp}} \newcommand{sgn}{\operatorname{sgn}} \newcommand{syn}{\operatorname{syn}} \newcommand{var}{\operatorname{Var}} \newcommand{res}{\operatorname{Res}} \newcommand{data}{\operatorname{data}} \newcommand{erfc}{\operatorname{erfc}} \newcommand{erfcx}{\operatorname{erfcx}} \newcommand{tr}{\operatorname{tr}} \newcommand{col}{\operatorname{Col}} \newcommand{row}{\operatorname{Row}} \newcommand{sol}{\operatorname{Sol}} \newcommand{lub}{\operatorname{lub}} \newcommand{glb}{\operatorname{glb}} \newcommand{ltrieq}{\mathrel{\unlhd}} \newcommand{ltri}{\mathrel{\lhd}} \newcommand{lr}{\leftrightarrow} \newcommand{phat}{^\widehat{\,\,\,}} \newcommand{what}{\widehat} \newcommand{wbar}{\overline} \newcommand{wtilde}{\widetilde} \newcommand{iid}{\operatorname{i.i.d.}} \newcommand{Exp}{\operatorname{Exp}} \newcommand{abs}[1]{\left| {#1}\right|} \newcommand{d}[2]{D_{\text{KL}}\left (#1\middle\| #2\right)} \newcommand{n}[1]{\|#1\|} \newcommand{norm}[1]{\left\|{#1}\right\|} \newcommand{pd}[2]{\left \langle {#1},{#2} \right \rangle} \newcommand{argmax}[1]{\underset{#1}{\operatorname{argmax}}} \newcommand{argmin}[1]{\underset{#1}{\operatorname{argmin}}} \newcommand{p}[1]{\left({#1}\right)} \newcommand{c}[1]{\left \{ {#1}\right\}} \newcommand{s}[1]{\left [{#1}\right]} \newcommand{a}[1]{\left \langle{#1}\right\rangle} 
\newcommand{cc}[2]{\left(\begin{array}{c} #1 \\ #2 \end{array}\right)} \newcommand{f}{\mathfrak F} \newcommand{fi}{\mathfrak F^{-1}} \newcommand{Fi}{\mathcal F^{-1}} \newcommand{l}{\mathfrak L} \newcommand{li}{\mathfrak L^{-1}} \newcommand{Li}{\mathcal L^{-1}} \newcommand{const}{\text{const.}} $$

Discrete Time Normalizing Flows #

Papamakarios, G., Nalisnick, E.T., Jimenez Rezende, D., Mohamed, S., & Lakshminarayanan, B. (2019). Normalizing Flows for Probabilistic Modeling and Inference. J. Mach. Learn. Res., 22, 57:1-57:64.

Review: Change of variable TODO #

Suppose $f: \R^d \to \R^d$ is a non-degenerate $C^1$-diffeomorphism (called simply "smooth" in the rest of this section), that is:

  • $f$ is a bijection with inverse $f^{-1}$.
  • $f \in C^1(\R^d \to \R^d)$ and $f^{-1} \in C^1(\R^d \to \R^d)$.
  • $\forall x \in \R^d: \abs{\det f'(x)} > 0$.

Suppose $p _ X(x) \in \L^1(\R^d \to [0, \infty])$ is a probability density. Suppose on some probability space $X \sim p _ X(x)$.

Let $Y = f(X) \sim p _ Y(y)$. If $y = f(x)$, then $$ p _ Y(y) = \frac{p _ X(f^{-1}(y))}{\abs{\det f'(f^{-1}(y))}}, \quad p _ X(x) = p _ Y(f(x)) \abs{\det f'(x)} $$ Taking logarithms of both sides gives: $$ \begin{aligned} \log p _ Y(y) &= \log p _ X(f^{-1}(y)) - \log \abs{\det f'(f^{-1}(y))}\\ \log p _ X(x) &= \log p _ Y(f(x)) + \log \abs{\det f'(x)} \end{aligned} $$
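As a worked instance of the formula above, here is a sketch for an affine map; the matrix $A$ and the vector $b$ are illustrative assumptions, not part of the notes. Take $X \sim \mathcal N(0, I)$ and $f(x) = Ax + b$ with $A \in \R^{d \times d}$ invertible, so that $f'(x) = A$ and $f^{-1}(y) = A^{-1}(y - b)$. Then $$ p _ Y(y) = \frac{(2\pi)^{-d/2} \exp\p{-\frac 1 2 \norm{A^{-1}(y - b)}^2}}{\abs{\det A}} = \frac{\exp\p{-\frac 1 2 (y - b)^\top (A A^\top)^{-1} (y - b)}}{(2\pi)^{d/2} \sqrt{\det (A A^\top)}} $$ which is the $\mathcal N(b, A A^\top)$ density, as expected for an affine image of a Gaussian. The second equality uses $\norm{A^{-1} v}^2 = v^\top (A A^\top)^{-1} v$ and $\sqrt{\det (A A^\top)} = \abs{\det A}$.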

Flows for generative modeling #

Suppose $Z \sim p _ {\psi}(z)$ where $p _ {\psi}(z) \in \L^1(\R^d \to [0, \infty])$ is a density function parameterized by $\psi \in \Psi$.

Suppose $f _ \theta: \R^d \to \R^d$ is a parameterized non-degenerate $C^1$-diffeomorphism. Let $X = f _ \theta(Z) \sim p _ {\theta, \psi}(x)$. Below, $p _ Z$ and $p _ X$ abbreviate $p _ \psi$ and $p _ {\theta, \psi}$.

To sample from $p _ X(x)$:

  • First sample $z$ from $p _ Z(z)$.
  • Then apply transform $x = f _ \theta(z) \sim p _ X(x)$.

To evaluate $\log p _ X(x)$:

  • First compute $z = f _ \theta^{-1}(x)$.
  • Then compute $\log\abs{\det f _ \theta'(z)}$ and $\log p _ Z(z)$, since $\log p _ X(x) = \log p _ Z(z) - \log \abs{\det f _ \theta'(z)}$.
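The two procedures above can be sketched for the simplest possible flow. This is a minimal illustration, not a reference implementation: it assumes a standard normal base density $p _ Z$ and an elementwise affine map $f _ \theta(z) = e^{s} \odot z + b$, and the names `log_scale`, `shift`, `forward`, `inverse`, `sample`, and `log_prob` are hypothetical.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def base_log_prob(z):
    """log p_Z(z) for a standard normal base density (an assumption of this sketch)."""
    return sum(-0.5 * zi * zi - 0.5 * LOG_2PI for zi in z)


def forward(z, log_scale, shift):
    """x = f_theta(z): elementwise affine map with scale exp(log_scale)."""
    return [math.exp(s) * zi + b for zi, s, b in zip(z, log_scale, shift)]


def inverse(x, log_scale, shift):
    """z = f_theta^{-1}(x)."""
    return [(xi - b) * math.exp(-s) for xi, s, b in zip(x, log_scale, shift)]


def log_abs_det(log_scale):
    """log|det f_theta'(z)|; the Jacobian is diagonal, so this is a plain sum."""
    return sum(log_scale)


def sample(rng, log_scale, shift):
    """Sampling: draw z from p_Z, then push it through f_theta."""
    z = [rng.gauss(0.0, 1.0) for _ in log_scale]
    return forward(z, log_scale, shift)


def log_prob(x, log_scale, shift):
    """Evaluation: log p_X(x) = log p_Z(z) - log|det f_theta'(z)| with z = f_theta^{-1}(x)."""
    z = inverse(x, log_scale, shift)
    return base_log_prob(z) - log_abs_det(log_scale)
```

Note that `sample` only calls `forward`, while `log_prob` only calls `inverse` and `log_abs_det`, which mirrors the split between the two procedures.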

Notice that efficient sampling and efficient log-probability evaluation have different requirements.

  • Efficient sampling requires fast sampling from $p _ Z$ and fast transform $f _ \theta$.
  • Efficient evaluation requires fast inverse $f _ \theta^{-1}$ , fast evaluation of $\log p _ Z(z)$, and fast evaluation of $\log \abs{\det f _ \theta'(z)}$.

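Suppose $X _ * \sim p _ * (x)$ is a data density on $\R^d$. Training with the forward KL minimizes $$ \d{p _ * }{p _ X} = \int _ {\R^d} p _ * (x) \log \frac{p _ * (x)}{p _ X(x)} \dd x = - \int _ {\R^d} p _ * (x) \log p _ X(x) \dd x + \const $$ over the model parameters. The constant does not depend on $(\theta, \psi)$, so this is equivalent to maximizing the expected log-likelihood $E\s{\log p _ X(X _ * )}$.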
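In practice the integral against $p _ *$ is replaced by an average over data samples, which gives the usual negative log-likelihood loss. The sketch below fits a one-dimensional affine flow $x = e^{s} z + b$ with $z \sim \mathcal N(0, 1)$ by plain gradient descent. The parameterization, the hand-derived gradients, and the names `nll` and `fit` are assumptions made for illustration.

```python
import math


def nll(data, log_scale, shift):
    """Sample estimate of -E[log p_X(X_*)] for x = exp(log_scale) * z + shift, z ~ N(0, 1)."""
    total = 0.0
    for x in data:
        z = (x - shift) * math.exp(-log_scale)  # z = f^{-1}(x)
        # -log p_X(x) = -log p_Z(z) + log|det f'(z)| = z^2/2 + log(2*pi)/2 + log_scale
        total += 0.5 * z * z + 0.5 * math.log(2.0 * math.pi) + log_scale
    return total / len(data)


def fit(data, steps=2000, lr=0.1):
    """Gradient descent on nll; the gradients are derived by hand for this affine flow."""
    log_scale, shift = 0.0, 0.0
    n = len(data)
    for _ in range(steps):
        inv = math.exp(-log_scale)
        zs = [(x - shift) * inv for x in data]
        grad_shift = -inv * sum(zs) / n                 # d nll / d shift
        grad_log_scale = 1.0 - sum(z * z for z in zs) / n  # d nll / d log_scale
        shift -= lr * grad_shift
        log_scale -= lr * grad_log_scale
    return log_scale, shift
```

For this model the stationary point is the maximum-likelihood Gaussian fit: `shift` converges to the sample mean and `exp(log_scale)` to the (biased) sample standard deviation.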

Flows for generative modeling: a normalizing transformation #

Continue the discussion above. Let $Z _ * = f _ \theta^{-1}(X _ * ) \sim p _ {Z _ * }(z) = p _ * (f _ \theta(z)) \abs{\det f _ \theta'(z)}$. Note that $Z \sim p _ Z(z) = p _ X(f _ \theta(z)) \abs{\det f _ \theta'(z)}$. Substituting $x = f _ \theta(z)$ and multiplying the numerator and denominator inside the logarithm by $\abs{\det f _ \theta'(z)}$ gives $$ \begin{aligned} \d{p _ * }{p _ X} & = \int _ {\R^d} p _ * (x) \log \frac{p _ * (x)}{p _ X(x)} \dd x\\ &= \int _ {\R^d} p _ * (f _ \theta(z)) \log \frac{p _ * (f _ \theta(z))}{p _ X(f _ \theta(z))} \abs{\det f _ \theta'(z)} \dd z\\ &= \int _ {\R^d} p _ {Z _ * }(z) \log \frac{p _ {Z _ * }(z)}{p _ Z(z)} \dd z = \d{p _ {Z _ * }}{p _ Z} \end{aligned} $$ In the context of generative models, a flow model can be interpreted as normalizing the data distribution, transforming it into the tractable prior distribution $p _ Z$.

Flow transform of $f$-divergence #

Suppose $g: \R^d \to \R^d$ is smooth in the sense above (a non-degenerate $C^1$-diffeomorphism). Here $f$ denotes the convex function that defines the divergence, not a flow. Suppose $X _ 1 \sim p _ 1(x)$ and $X _ 2 \sim p _ 2(x)$ on $\R^d$. Suppose $Y _ 1 = g(X _ 1) \sim q _ 1(y)$, and $Y _ 2 = g(X _ 2) \sim q _ 2(y)$. Now observe the following: $$ \begin{aligned} D _ f(p _ 1 \Vert p _ 2) & = \int _ {\R^d} f\p{\frac{p _ 1(x)}{p _ 2(x)}} p _ 2(x) \dd x\\ & = \int _ {\R^d} f \p{\frac{q _ 1(g(x))\abs{\det g'(x)}}{q _ 2(g(x))\abs{\det g'(x)}}} q _ 2(g(x)) \abs{\det g'(x)} \dd x\\ & = \int _ {\R^d} f \p{\frac{q _ 1(y)}{q _ 2(y)}} q _ 2(y) \dd y = D _ f(q _ 1 \Vert q _ 2) \end{aligned} $$ Hence every $f$-divergence, the KL-divergence included, is invariant under $g$.
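The invariance derived above can be checked in closed form for one narrow special case: the KL-divergence between two one-dimensional Gaussians pushed through an affine map $g(x) = ax + c$ with $a \neq 0$. This is only a sanity check under those assumptions, not a test of the general statement, and the helper names `kl_gauss` and `push_affine` are hypothetical.

```python
import math


def kl_gauss(m1, s1, m2, s2):
    """KL( N(m1, s1^2) || N(m2, s2^2) ) in closed form (s1, s2 are standard deviations)."""
    return math.log(s2 / s1) + (s1 * s1 + (m1 - m2) ** 2) / (2.0 * s2 * s2) - 0.5


def push_affine(m, s, a, c):
    """Mean and standard deviation of g(X) = a * X + c when X ~ N(m, s^2), a != 0."""
    return a * m + c, abs(a) * s


# Compare the divergence before and after applying the same map g to both variables.
before = kl_gauss(0.5, 1.2, -1.0, 0.7)
after = kl_gauss(*push_affine(0.5, 1.2, -3.0, 4.0), *push_affine(-1.0, 0.7, -3.0, 4.0))
print(before, after)
```

The two printed values agree up to floating-point rounding, because the factor $a^2$ cancels in the quadratic term and $\abs a$ cancels in the log-ratio of standard deviations.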

Flows for variational inference #

Suppose $p(z, x)$ is a latent variable model, where $z \in \R^d$. Let the family of approximate posteriors $\mathcal Q$ in variational inference be a flow model.

Suppose $U \sim q _ U(u)$ is the base distribution, and $\what Z = f _ \phi(U) \sim q _ \phi(z)$ has a density parameterized by a flow model, so that $q _ \phi(f _ \phi(u)) \abs{\det f _ \phi'(u)} = q _ U(u)$.

$$ \newcommand{\elbo}{\operatorname{ELBO}} \begin{aligned} \elbo(p, q _ \phi, x) & = E \s{\log \frac{p(\what Z, x)}{q _ \phi(\what Z)}} = - \d{q _ \phi(z)}{p(z|x)} + \log p(x)\\ & = \int _ {\R^d} \log \frac{p(z, x)}{q _ \phi(z)} q _ \phi(z) \dd z = \int _ {\R^d} \log \frac{p(f _ \phi(u), x)}{q _ \phi(f _ \phi(u))} q _ {\phi}(f _ \phi(u)) \abs{\det f _ \phi'(u)}\dd u\\ & = \int _ {\R^d} \log \frac{p(f _ \phi(u), x)}{q _ U(u)/\abs{\det f _ {\phi}'(u)}} q _ U(u) \dd u\\ & = E[\log p(f _ \phi(U), x)] + H(U) + E\s{\log \abs{\det f' _ \phi(U)}} \end{aligned} $$ Flow models in the context of variational inference can be thought of as implementing a generalized reparameterization trick.

For efficient optimization of the ELBO, we need fast sampling of $U$ and fast evaluation of $f _ \phi(u)$ and $\log \abs{\det f _ \phi'(u)}$. The inverse $f _ \phi^{-1}$ is never needed.
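The last line of the ELBO derivation translates directly into a Monte Carlo estimator. The sketch below uses a toy conjugate model chosen only because its evidence is known exactly: $z \sim \mathcal N(0, 1)$, $x \mid z \sim \mathcal N(z, 1)$, so $x \sim \mathcal N(0, 2)$ and the true posterior is $\mathcal N(x/2, 1/2)$. The model, the affine flow $f _ \phi(u) = m + e^{s} u$, and all function names are assumptions for illustration.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def log_joint(z, x):
    """log p(z, x) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return -0.5 * z * z - 0.5 * (x - z) ** 2 - LOG_2PI


def elbo_estimate(x, shift, log_scale, num_samples, rng):
    """Estimate E[log p(f(U), x)] + H(U) + E[log|f'(U)|] with f(u) = shift + exp(log_scale) * u."""
    entropy_u = 0.5 * (LOG_2PI + 1.0)  # H(U) of a standard normal, known in closed form
    log_det = log_scale                # log|f'(u)| is constant for an affine map
    total = 0.0
    for _ in range(num_samples):
        u = rng.gauss(0.0, 1.0)                                   # sample the base variable
        total += log_joint(shift + math.exp(log_scale) * u, x)   # push through the flow
    return total / num_samples + entropy_u + log_det


def log_evidence(x):
    """Exact log p(x); marginally x ~ N(0, 2) in the toy model."""
    return -0.5 * math.log(4.0 * math.pi) - x * x / 4.0
```

With `shift = x / 2` and `log_scale = 0.5 * math.log(0.5)` the flow reproduces the exact posterior and the estimate approaches `log_evidence(x)`; any other setting falls below it by the KL gap, up to Monte Carlo noise.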

Flows for enhancing approximate posteriors in VAEs #

Jimenez Rezende, D., & Mohamed, S. (2015). Variational Inference with Normalizing Flows. International Conference on Machine Learning.

Kingma, D.P., Salimans, T., Józefowicz, R., Chen, X., Sutskever, I., & Welling, M. (2016). Improving Variational Autoencoders with Inverse Autoregressive Flow. NIPS.

Similar to the case of variational inference, amortized variational inference in VAEs also benefits from flow transforms.

Suppose the approximate posterior is a density function $q _ {\phi}(z | x)$ on $\R^d$, and $\wtilde Z \sim q _ \phi(z | x)$.

Suppose $f _ \eta: \R^d \to \R^d$ is smooth, and let $\what Z = f _ \eta(\wtilde Z) \sim q _ {\phi, \eta}(z|x)$.

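$$ \begin{aligned} \elbo(p _ \theta, q _ {\phi, \eta}, x) & = E\s{\log \frac{p _ \theta(x | \what Z) p _ \theta(\what Z)}{q _ {\phi, \eta}(\what Z | x)}}\\ & = \log p _ \theta(x) - \boxed{\d{q _ {\phi, \eta}(z | x)}{p _ \theta(z | x)}}\\ & = E\s{\log p _ \theta(x, f _ \eta(\wtilde Z))} + H(\wtilde Z) + E\s{\log \abs{\det f' _ \eta(\wtilde Z)}} \end{aligned} $$ The last line uses $q _ {\phi, \eta}(f _ \eta(\wtilde z) | x) = q _ \phi(\wtilde z | x) / \abs{\det f _ \eta'(\wtilde z)}$, so the log-determinant enters with a plus sign, as in the variational inference case above.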

To run stochastic gradient variational Bayes (SGVB) efficiently with the flow-enhanced posterior, we need

  • efficient evaluation of $f _ \eta(\wtilde z)$ and
  • efficient evaluation of $\log \abs{\det f _ \eta'(\wtilde z)}$.
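One transform that meets both requirements is the planar flow from the Rezende and Mohamed paper cited above, $f(z) = z + u \tanh(w^\top z + b)$. Its Jacobian is a rank-one update of the identity, so the matrix determinant lemma gives $\det f'(z) = 1 + \tanh'(w^\top z + b)\, u^\top w$ in $O(d)$ time. The sketch below is an illustration only: the function names are hypothetical, and it assumes $u^\top w \geq -1$, which is the condition under which this map is invertible with $\tanh$.

```python
import math


def planar_forward(z, u, w, b):
    """f(z) = z + u * tanh(w . z + b)."""
    a = sum(wi * zi for wi, zi in zip(w, z)) + b
    h = math.tanh(a)
    return [zi + ui * h for zi, ui in zip(z, u)]


def planar_log_abs_det(z, u, w, b):
    """log|det f'(z)| = log|1 + tanh'(w . z + b) * (u . w)|, computed in O(d)."""
    a = sum(wi * zi for wi, zi in zip(w, z)) + b
    dh = 1.0 - math.tanh(a) ** 2  # derivative of tanh at a
    return math.log(abs(1.0 + dh * sum(ui * wi for ui, wi in zip(u, w))))
```

Neither function needs the inverse of $f$, which matches the fact that the ELBO above only evaluates the flow in the forward direction.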

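Since the KL-divergence is invariant under the diffeomorphism $f _ \eta^{-1}$, the boxed term equals $\d{q _ \phi(\wtilde z | x)}{\wtilde p(\wtilde z | x)}$, where $\wtilde p(\wtilde z | x) = p _ \theta(f _ \eta(\wtilde z) | x) \abs{\det f _ \eta'(\wtilde z)}$ is the density of the true posterior transformed by $f _ \eta^{-1}$. The flow can therefore be interpreted as normalizing the true posterior into a simpler distribution that $q _ \phi$ can match.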