\definecolor{cb1}{RGB}{76,114,176} \definecolor{cb2}{RGB}{221,132,82} \definecolor{cb3}{RGB}{85,168,104} \definecolor{cb4}{RGB}{196,78,82} \definecolor{cb5}{RGB}{129,114,179} \definecolor{cb6}{RGB}{147,120,96} \definecolor{cb7}{RGB}{218,139,195} \newcommand{\abs}[1]{\left\lvert #1 \right\rvert} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator{\bigO}{\mathcal{O}} \DeclareMathOperator{\D}{\mathcal{D}} \DeclareMathOperator*{\E}{\mathbb{E}} \DeclareMathOperator*{\Var}{Var} \DeclareMathOperator{\mean}{mean} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\spn}{span} \newcommand{\tp}{^\mathsf{T}} \DeclareMathOperator{\tr}{tr} \newcommand{\ud}{\mathrm{d}} \DeclareMathOperator{\W}{\mathcal{W}} \DeclareMathOperator{\KL}{KL} \DeclareMathOperator{\JS}{JS} \DeclareMathOperator{\mmd}{MMD} \DeclareMathOperator{\smmd}{SMMD} \DeclareMathOperator{\hsic}{HSIC} \DeclareMathOperator{\mmdhat}{\widehat{MMD}} \DeclareMathOperator{\hsichat}{\widehat{HSIC}} \newcommand{\optmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{MMD}^{#1}}} \DeclareMathOperator{\optmmdhat}{\hat{\mathcal{D}}_\mathrm{MMD}} \newcommand{\optsmmdp}{\operatorname{\mathcal{D}_\mathrm{SMMD}}} \newcommand{\optsmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{SMMD}^{\SS,#1,\lambda}}} \newcommand{\ktop}{k_\mathrm{top}} \newcommand{\lip}{\mathrm{Lip}} \newcommand{\cH}{\mathcal{H}} \newcommand{\h}{\mathcal{H}} \newcommand{\R}{\mathbb{R}} \newcommand{\cP}[1]{{\color{cb1} #1}} \newcommand{\PP}{\cP{\mathbb P}} \newcommand{\pp}{\cP{p}} \newcommand{\X}{\cP{X}} \newcommand{\Xp}{\cP{X'}} % kexpfam colors \newcommand{\nc}{{\color{#d62728}{n}}} \newcommand{\Xc}{{\color{#d62728}{X}}} \newcommand{\Mc}{{\color{#1f77b4}{M}}} \newcommand{\Yc}{{\color{#1f77b4}{Y}}} \newcommand{\mc}{{\color{#17becf}{m}}} \newcommand{\dc}{{\color{#2ca02c}{d}}} \newcommand{\Pdata}{\cP{\mathbb{P}_\mathrm{data}}} \newcommand{\cQ}[1]{{\color{cb2} #1}} \newcommand{\QQ}{\cQ{\mathbb Q}} \newcommand{\qq}{\cQ{q}} \newcommand{\qtheta}{\cQ{q_\theta}} \newcommand{\Y}{\cQ{Y}} \newcommand{\Yp}{\cQ{Y'}} \newcommand{\thetac}{\cQ{\theta}} \newcommand{\vtheta}{\thetac} \newcommand{\Qtheta}{\QQ_\thetac} \newcommand{\Gtheta}{\cQ{G_\theta}} \newcommand{\cZ}[1]{{\color{cb5} #1}} \newcommand{\Z}{\cZ{Z}} \newcommand{\Zc}{\cZ{\mathcal Z}} \newcommand{\ZZ}{\cZ{\mathbb Z}} \newcommand{\cpsi}[1]{{\color{cb3} #1}} \newcommand{\psic}{\cpsi{\psi}} \newcommand{\Psic}{\cpsi{\Psi}} \newcommand{\Dpsi}{\cpsi{D_\psi}} \newcommand{\SS}{\cpsi{\mathbb{S}}} \newcommand{\Xtilde}{\cpsi{\tilde{X}}} \newcommand{\Xtildep}{\cpsi{\tilde{X}'}}

Better deep learning (sometimes)
by learning kernel mean embeddings

Danica J. Sutherland (she/her)
University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
slides (but not the talk) about four related projects:
Feng Liu
Wenkai Xu
Hsiao-Yu (Fish) Tung
(+ 6 more…)
Yazhe Li
Roman Pogodin
Michael Arbel
Mikołaj Bińkowski
Li (Kevin) Wenliang
Heiko Strathmann
+ all with Arthur Gretton

LIKE22 - 12 Jan 2022


Deep learning and deep kernels

  • Deep learning: models usually of form f(x) = w\tp \phi_\psic(x)
    • With a learned \phi_\psic(x) : \mathcal X \to \R^{D}
  • If we fix \psic , have f \in \cH_\psic with k_\psic(x, y) = \phi_\psic(x)\tp \phi_\psic(y)
    • Same idea as NNGP approximation
  • Could train a classifier by:
    • Let \tilde L(\psic) = L(f_\psic^*) , loss of the best f_\psic^* \in \cH_\psic
    • Learn \psic by following \nabla_\psic \tilde L(\psic) = \nabla_\psic L(f^*_\psic)
  • Generalize to a deep kernel: k_\psic(x, y) = \kappa\left( \phi_\psic(x), \phi_\psic(y) \right)
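As a rough illustration (not code from the talk), a deep kernel might look like this minimal PyTorch sketch, where FeatureNet is a hypothetical feature extractor \phi_\psic and the outer kernel \kappa is Gaussian:

```python
# Minimal sketch of a deep kernel k_psi(x, y) = kappa(phi_psi(x), phi_psi(y)).
# FeatureNet and the layer sizes are hypothetical; kappa here is a Gaussian kernel.
import torch
import torch.nn as nn

class FeatureNet(nn.Module):
    """phi_psi: maps inputs in R^d_in to features in R^D."""
    def __init__(self, d_in, d_out=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, 64), nn.ReLU(),
            nn.Linear(64, d_out),
        )

    def forward(self, x):
        return self.net(x)

def deep_kernel(phi, x, y, lengthscale=1.0):
    """k_psi(x, y) = exp(-||phi(x) - phi(y)||^2 / (2 * lengthscale^2))."""
    fx, fy = phi(x), phi(y)                 # (n, D) and (m, D) feature matrices
    sq_dists = torch.cdist(fx, fy) ** 2     # pairwise squared distances in feature space
    return torch.exp(-sq_dists / (2 * lengthscale ** 2))
```

Taking \kappa to be the plain inner product instead recovers the k_\psic(x, y) = \phi_\psic(x)\tp \phi_\psic(y) case above.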

Normal deep learning \subset deep kernels

  • Take \phi_\psic(x) \in \R as output of last layer
  • k_\psic(x, y) = \phi_\psic(x) \phi_\psic(y) + 1
  • Final function in \cH_\psic will be a \phi_\psic(x) + b
  • With logistic loss: this is Platt scaling

So what?

  • This definitely does not say that deep learning is (even approximately) a kernel method
  • …despite what some people might want you to think
  • We know theoretically deep learning can learn some things faster than any kernel method [see Malach+ ICML-21 + refs]
  • But deep kernel learning ≠ traditional kernel models
    • exactly like how usual deep learning ≠ linear models

What do deep kernels give us?

  • In “normal” classification: slightly richer function space 🤷
  • Meta-learning: common \phi_\psic , constantly varying f_\psic^*
  • Two-sample testing
    • Simple form of f_\psic^* for cheap permutation testing
  • Self-supervised learning
    • Better understanding of what's really going on, at least
  • Generative modeling with MMD GANs
    • Better gradient for generator to follow (?)
  • Score matching in exponential families (density estimation)
    • Optimize regularization weights, better gradient (?)

Maximum Mean Discrepancy (MMD)

\begin{align} \mmd_k(\PP, \QQ) &= \sup_{\lVert f \rVert_\cH \le 1} \E_{\X \sim \PP}[f(\X)] - \E_{\Y \sim \QQ}[f(\Y)] \\ &= \sup_{\lVert f \rVert_\cH \le 1} \E_{\X \sim \PP}[\langle f, \varphi(\X) \rangle_\cH] - \E_{\Y \sim \QQ}[\langle f, \varphi(\Y) \rangle_\cH] \\ &= \sup_{\lVert f \rVert_\cH \le 1} \left\langle f, \E_{\X \sim \PP}[\varphi(\X)] - \E_{\Y \sim \QQ}[\varphi(\Y)] \right\rangle_\cH \\ &= \sup_{\lVert f \rVert_\cH \le 1} \left\langle f, \mu^k_\PP - \mu^k_\QQ \right\rangle_\cH = \left\lVert \mu^k_\PP - \mu^k_\QQ \right\rVert_\cH \end{align}

MMD as feature matching

\mmd_k(\PP, \QQ) = \left\lVert \E_{\X \sim \PP}[ \varphi(\X) ] - \E_{\Y \sim \QQ}[ \varphi(\Y) ] \right\rVert_{\cH}

  • \varphi : \mathcal X \to \cH is the feature map for k(x, y) = \langle \varphi(x), \varphi(y) \rangle
  • If k(x, y) = x\tp y , \varphi(x) = x , then
    the MMD is distance between means
  • Many kernels: infinite-dimensional \cH

Estimating MMD

\begin{gather} \mmd_k^2(\PP, \QQ) = \E_{\X, \Xp \sim \PP}[k(\X, \Xp)] + \E_{\Y, \Yp \sim \QQ}[k(\Y, \Yp)] - 2 \E_{\substack{\X \sim \PP\\\Y \sim \QQ}}[k(\X, \Y)] \\ \mmdhat_k^2(\X, \Y) = \mean(K_{\X\X}) + \mean(K_{\Y\Y}) - 2 \mean(K_{\X\Y}) \end{gather}

K_{\X\X} = \begin{pmatrix} 1.0 & 0.2 & 0.6 \\ 0.2 & 1.0 & 0.5 \\ 0.6 & 0.5 & 1.0 \end{pmatrix} \qquad K_{\Y\Y} = \begin{pmatrix} 1.0 & 0.8 & 0.7 \\ 0.8 & 1.0 & 0.6 \\ 0.7 & 0.6 & 1.0 \end{pmatrix} \qquad K_{\X\Y} = \begin{pmatrix} 0.3 & 0.1 & 0.2 \\ 0.2 & 0.3 & 0.3 \\ 0.2 & 0.1 & 0.4 \end{pmatrix}
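Plugging the three kernel matrices above into the estimator is a one-liner; a minimal NumPy sketch (the biased, V-statistic version that keeps the diagonals):

```python
# Biased (V-statistic) estimate of MMD^2 from precomputed kernel matrices,
# using the illustrative 3x3 matrices from this slide.
import numpy as np

def mmd2_biased(K_XX, K_YY, K_XY):
    return K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()

K_XX = np.array([[1.0, 0.2, 0.6], [0.2, 1.0, 0.5], [0.6, 0.5, 1.0]])
K_YY = np.array([[1.0, 0.8, 0.7], [0.8, 1.0, 0.6], [0.7, 0.6, 1.0]])
K_XY = np.array([[0.3, 0.1, 0.2], [0.2, 0.3, 0.3], [0.2, 0.1, 0.4]])

print(mmd2_biased(K_XX, K_YY, K_XY))  # ~0.956: within-sample similarity >> cross-sample
```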

I: Two-sample testing

  • Given samples from two unknown distributions \X \sim \PP \qquad \Y \sim \QQ
  • Question: is \PP = \QQ ?
  • Hypothesis testing approach: H_0: \PP = \QQ \qquad H_1: \PP \ne \QQ
  • Reject H_0 if test statistic \hat T(\X, \Y) > c_\alpha
  • Do smokers/non-smokers get different cancers?
  • Do Brits have the same friend network types as Americans?
  • When does my laser agree with the one on Mars?
  • Are storms in the 2000s different from storms in the 1800s?
  • Does presence of this protein affect DNA binding? [MMDiff2]
  • Do these dob and birthday columns mean the same thing?
  • Does my generative model \Qtheta match \Pdata ?
  • Independence testing: is P(\X, \Y) = P(\X) P(\Y) ?

What's a hypothesis test again?


Permutation testing to find c_\alpha

Need \Pr_{H_0}\left( T(\X, \Y) > c_\alpha \right) \le \alpha

Pool X_1, \dots, X_5, Y_1, \dots, Y_5 , shuffle, and re-split into \cP{\tilde X_j}, \cQ{\tilde Y_j}

c_\alpha : (1-\alpha) th quantile of \left\{ \hat T(\cP{\tilde X_1}, \cQ{\tilde Y_1}), \; \hat T(\cP{\tilde X_2}, \cQ{\tilde Y_2}), \; \cdots \right\}
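Putting the pieces together, a permutation test with the MMD statistic might look like this sketch (the `kernel` argument, which returns a pairwise kernel matrix, is assumed; e.g. the Gaussian deep kernel sketched earlier):

```python
# Sketch of a permutation test using the biased MMD^2 statistic.
import numpy as np

def mmd2_from_joint(K, n):
    """Biased MMD^2 from the kernel matrix K on the pooled sample [X; Y], with |X| = n."""
    K_XX, K_YY, K_XY = K[:n, :n], K[n:, n:], K[:n, n:]
    return K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()

def permutation_test(X, Y, kernel, n_perms=500, alpha=0.05, seed=None):
    rng = np.random.default_rng(seed)
    n = len(X)
    Z = np.concatenate([X, Y])            # pooled sample
    K = kernel(Z, Z)                      # kernel matrix computed only once
    observed = mmd2_from_joint(K, n)
    null_stats = []
    for _ in range(n_perms):
        perm = rng.permutation(len(Z))    # relabel points: valid under H_0, P = Q
        Kp = K[np.ix_(perm, perm)]
        null_stats.append(mmd2_from_joint(Kp, n))
    c_alpha = np.quantile(null_stats, 1 - alpha)   # (1 - alpha) quantile of the null
    return observed > c_alpha, observed, c_alpha
```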

MMD-based tests

  • If k is characteristic, \mmd(\PP, \QQ) = 0 iff \PP = \QQ
  • Efficient permutation testing for \mmdhat(\X, \Y)
    • H_0 : n \mmdhat^2 converges in distribution
    • H_1 : \sqrt n ( \mmdhat^2 - \mmd^2 ) asymptotically normal
  • Any characteristic kernel gives a consistent test… eventually
  • But you may need an enormous n if the kernel is poorly suited to the problem

Classifier two-sample tests

  • Train a classifier f to distinguish \X from \Y ; \hat T(\X, \Y) is its accuracy on a held-out test set
  • Under H_0 , classification is impossible: n \hat T \sim \mathrm{Binomial}(n, \frac12)
  • With k(x, y) = \frac14 f(x) f(y) where f(x) \in \{-1, 1\} ,
    get \mmdhat(\X, \Y) = \left\lvert \hat{T}(\X, \Y) - \frac12 \right\rvert

Optimizing test power

  • Asymptotics of \mmdhat^2 give us immediately that \Pr_{H_1}\left( n \mmdhat^2 > c_\alpha \right) \approx \Phi\left( \frac{\sqrt n \mmd^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt n \sigma_{H_1}} \right)
  • \mmd , \sigma_{H_1} , c_\alpha are constants: the first term dominates
  • Pick k to maximize an estimate of \mmd^2 / \sigma_{H_1}
  • Can show uniform \mathcal O_P(n^{-\frac13}) convergence of estimator
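In code, this might look like the following sketch: estimate \mmdhat^2 and \hat\sigma_{H_1} from "training" data and ascend their ratio with respect to the kernel parameters (a Gaussian kernel on features from a network like the earlier FeatureNet sketch; the variance estimator and the stabilizing \epsilon here are illustrative, not the exact recipe from the paper):

```python
# Sketch: maximize an estimate of MMD^2 / sigma_{H_1} over deep-kernel parameters.
# Assumes equal sample sizes for X and Y; details are illustrative.
import torch

def gaussian_kernel(fx, fy, lengthscale):
    return torch.exp(-torch.cdist(fx, fy) ** 2 / (2 * lengthscale ** 2))

def power_criterion(X, Y, phi, lengthscale, eps=1e-8):
    fx, fy = phi(X), phi(Y)
    K_XX = gaussian_kernel(fx, fx, lengthscale)
    K_YY = gaussian_kernel(fy, fy, lengthscale)
    K_XY = gaussian_kernel(fx, fy, lengthscale)
    mmd2 = K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()

    # H_ij = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(y_i, x_j)
    H = K_XX + K_YY - K_XY - K_XY.t()
    var_h1 = 4 * (H.mean(dim=1) ** 2).mean() - 4 * H.mean() ** 2
    return mmd2 / torch.sqrt(var_h1.clamp_min(eps))

# Training sketch: ascend the criterion on one data split, then run the
# permutation test with the learned kernel on the held-out split.
#   loss = -power_criterion(X_tr, Y_tr, phi, lengthscale)
#   loss.backward(); optimizer.step()
```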

Blobs dataset

Blobs kernels

Blobs results

CIFAR-10 vs CIFAR-10.1

Train on 1 000, test on 1 031, repeat 10 times. Rejection rates:
ME     SCF    C2ST   MMD-O  MMD-D
0.588  0.171  0.452  0.316  0.744

Ablation vs classifier-based tests

Rejection rates; statistics Sign / Lin / Ours, trained with cross-entropy vs. the max-power criterion:

Blobs:               0.84  0.94  0.90  0.95  0.99
High-d Gauss. mix.:  0.47  0.59  0.29  0.64  0.66
Higgs:               0.26  0.40  0.35  0.30  0.40
MNIST vs GAN:        0.65  0.71  0.80  0.94  1.00

But…

  • What if you don't have much data for your testing problem?
  • Need enough data to pick a good kernel
  • Also need enough test data to actually detect the difference
  • Best split depends on best kernel's quality / how hard to find
    • Don't know that ahead of time; can't try more than one

Meta-testing

  • One idea: what if we have related problems?
  • Similar setup to meta-learning (figure from Wei+ 2018):

Meta-testing for CIFAR-10 vs CIFAR-10.1

  • CIFAR-10 has 60,000 images, but CIFAR-10.1 only has 2,031
  • Where do we get related data from?
  • One option: set up tasks to distinguish classes of CIFAR-10 (airplane vs automobile, airplane vs bird, ...)

One approach (MAML-like)

A_\theta is, e.g., 5 steps of gradient descent

We learn the initialization, maybe the step size, etc.

This works, but not as well as we'd hoped…
The initialization might work okay on everything, but not really adapt to each task

Another approach: Meta-MKL

Inspired by classic multiple kernel learning

Only need to learn linear combination \beta_i on test task:
much easier

Theoretical analysis for Meta-MKL

  • Same big-O dependence on test task size 😐
  • But multiplier is much better:
    based on number of meta-training tasks, not on network size
  • (Analysis assumes meta-tasks are “related” enough)

Results on CIFAR-10.1

Challenges for testing

  • When \PP \ne \QQ , can we tell how they're different?
    • Methods so far: some mostly for low-dimensional data
    • Some look at points with large critic function
  • Finding kernels / features that can't do certain things
    • distinguish by emotion, but can't distinguish by skin color
  • Avoid the need for data splitting (selective inference)
    • Kübler+ NeurIPS-20 gave one method
    • only for multiple kernel learning
    • only with data-inefficient (streaming) estimator

II: Self-supervised learning

Given a bunch of unlabeled samples \X ,
want to find “good” features \Z = f(\X)
(e.g. so that a linear classifier on \Z works with few samples)

One common approach: contrastive learning

InfoNCE [van den Oord+ 2018, Poole+ ICML-19]

\begin{align*} \E_{\Z_1}\left[ \E_{\Z_2 \sim \text{pos}}[ k(\Z_1, \Z_2)] - \log \E_{\Z_2}[ \exp(k(\Z_1, \Z_2))] \right] \le \operatorname{MI}(\Z_1, \Z_2) \end{align*}
Variants:
CPC, SimCLR, MoCo, SwAV, …

Mutual information isn't why SSL works!

InfoNCE approximates MI between "positive" views

But MI is invariant to invertible transformations of the representation, which matter a lot for SSL!

Hilbert-Schmidt Independence Criterion (HSIC)

\begin{align*} \hsic(\X, \Y) &= \norm{ \E[ \phi(\X) \otimes \phi(\Y) ] - \E[ \phi(\X) ] \otimes \E[ \phi(\Y) ] }_{HS}^2 \\&= \mmd^2(\PP_{\X\Y}, \PP_\X \otimes \PP_\Y) \\&\le C_k \operatorname{MI}(\X, \Y) \qquad\text{($C_k$ depends only on $\norm{k}_\infty$)} \end{align*}

With a linear kernel: \hsic = \norm{ \E[\X \Y\tp] - \E[\X] \E[\Y]\tp }_F^2

Estimator based on kernel matrices: \hsichat = \frac{1}{(n-1)^2} \tr(\cP{K} H \cQ{L} H)
  • H is the centering matrix
  • \cP{K} is the kernel matrix on \X
  • \cQ{L} is the kernel matrix on \Y
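A minimal NumPy sketch of this biased estimator:

```python
# Biased HSIC estimator: HSIC_hat = tr(K H L H) / (n - 1)^2.
import numpy as np

def hsic_biased(K, L):
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n     # centering matrix
    return np.trace(K @ H @ L @ H) / (n - 1) ** 2
```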

SSL-HSIC

\mathcal L_{\text{SSL-HSIC}} = -\hsic(\Z, \Y) + \gamma \sqrt{\hsic(\Z, \Z)}
( \Y is just an indicator of which source image it came from)

Target representation \Z' is output of f_\theta(\X)

HSIC uses learned kernel k(\Z'_1, \Z'_2) = \operatorname{IMQ}(g(\Z'_1), g(\Z'_2))
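A much-simplified sketch of this loss, using the biased HSIC estimator from above, a plain IMQ kernel on the representations, and a one-hot (linear) kernel on image identities; the actual method's estimators and kernel parameterization differ:

```python
# Rough sketch of L_SSL-HSIC = -HSIC(Z, Y) + gamma * sqrt(HSIC(Z, Z)),
# where Y indicates which source image each augmented view came from.
# Kernel choices and estimators here are simplified, not the paper's exact ones.
import torch
import torch.nn.functional as F

def imq_kernel(a, b, c=1.0):
    """Inverse multiquadric kernel on feature vectors."""
    return c / torch.sqrt(c ** 2 + torch.cdist(a, b) ** 2)

def hsic_biased(K, L):
    n = K.shape[0]
    H = torch.eye(n) - torch.ones(n, n) / n
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2

def ssl_hsic_loss(Z, image_ids, gamma=1.0):
    K = imq_kernel(Z, Z)                        # kernel on representations
    Y = F.one_hot(image_ids).float()
    L = Y @ Y.t()                               # linear (delta) kernel on image identity
    return -hsic_biased(K, L) + gamma * torch.sqrt(hsic_biased(K, K))
```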

Your InfoNCE model is secretly a kernel method…

\begin{gather*} \mathcal{L}_{\mathrm{InfoNCE}}(\theta) = - \E_{(\Z_1,\Z_2) \sim \mathrm{pos}}\left[ k(\Z_1, \Z_{2}) \right] + \E_{\Z_1} \log \E_{\Z_2} \left[ \exp k(\Z_{1}, \Z_{2}) \right] \\ \approx \underbrace{- \E_{(\Z_1,\Z_2) \sim \mathrm{pos}}\left[k(\Z_1, \Z_2)\right] + \E_{\Z_1} \E_{\Z_2} \left[k(\Z_1, \Z_2) \right]}_{\propto -\hsic(\Z, \Y)} + \frac12 \underbrace{\E_{\Z_1} \left[\Var_{\Z_2} \left[ k(\Z_1, \Z_2) \right] \right]}_{\text{variance penalty}} \end{gather*}

Very similar loss! Just different regularizer

When variance is small,
-\hsic(\Z, \Y) + \gamma \hsic(\Z, \Z) \le \mathcal L_{\mathrm{InfoNCE}} + o(\mathrm{variance})

Clustering interpretation

SSL-HSIC estimates agreement of \Z with cluster structure of \Y

With linear kernels: -\hsic(\Z, \Y) \propto \sum_{i=1}^n \sum_{p=1}^m \norm{\cZ{Z_i^{(p)}} - \cZ{\bar Z_i}}^2 - n m where \cZ{\bar Z_i} is mean of the m augmentations

Resembles BYOL loss with no target network (but still works!)

ImageNet results: linear evaluation

Transfer from ImageNet to classification tasks

III: Training implicit generative models

Given samples from a distribution \PP over \mathcal X ,
we want a model that can produce new samples from \Qtheta \approx \PP


\X \sim \PP

\Y \sim \Qtheta
“Everybody Dance Now” [Chan+ ICCV-19]

Generator networks

Fixed distribution of latents: \Z \sim \mathrm{Uniform}\left( [-1, 1]^{100} \right)

Maps through a network: G_\thetac(\Z) \sim \Qtheta

DCGAN generator [Radford+ ICLR-16]

How to choose \thetac ?

GANs and their flaws

  • GANs [Goodfellow+ NeurIPS-14] minimize discriminator accuracy (like classifier test) between \PP and \Qtheta
  • Problem: if there's a perfect classifier, the loss is discontinuous in \thetac : no gradient for the generator to improve [Arjovsky/Bottou ICLR-17]
  • At initialization, \PP and \cQ{\QQ_{\theta_0}} typically have disjoint supports
  • For usual \Gtheta : \R^{100} \to \R^{64 \times 64 \times 3} , \Qtheta is supported on a countable union of manifolds with dim \le 100
  • “Natural image manifold” usually considered low-dim
  • Won't align at init, so won't ever align

WGANs and MMD GANs

  • Integral probability metrics with “smooth” \mathcal F are continuous
  • WGAN: \mathcal F a set of neural networks satisfying \norm{f}_L \le 1
  • WGAN-GP: instead penalize \E \norm{\nabla_x f(x)} near the data
  • Both losses are MMD with k_\psic(x, y) = \phi_\psic(x) \phi_\psic(y)
    • \min_\theta \left[ \optmmd(\PP, \Qtheta) = \sup_{\psic \in \Psic} \mmd_{\psic}(\PP, \Qtheta) \right]
  • Some kind of constraint on \phi_\psic is important!

Non-smoothness of plain MMD GANs

Illustrative problem in \R , DiracGAN [Mescheder+ ICML-18]:

  • Just need to stay away from tiny bandwidths \psic
  • …but the deep-kernel analogue of that is hard
  • Instead, keep the witness function from being too steep
  • Constraining \sup_x \lVert \nabla f(x) \rVert \le 1 would give Wasserstein
    • Nice distance, but hard to estimate
  • Control \lVert \nabla f(\Xtilde) \rVert on average, near the data

MMD-GAN with gradient control

  • If \Psic gives uniformly Lipschitz critics, \optmmd is smooth
  • Original MMD-GAN paper [Li+ NeurIPS-17]: box constraint
  • We [Bińkowski+ ICLR-18] used a gradient penalty on the critic instead
    • Better in practice, but doesn't fix the Dirac problem…

New distance: Scaled MMD

Want to ensure \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] \le 1

Could enforce this exactly using \langle \partial_i \phi(x), f \rangle_{\cH} = \partial_i f(x) …but it's too expensive!

Guaranteed if \lVert f \rVert_\cH \le \sigma_{\SS,k,\lambda}

\sigma_{\SS,k,\lambda} := \left( \lambda + \E_{\Xtilde \sim \SS}\left[ k(\Xtilde, \Xtilde) + [\nabla_1 \!\cdot\! \nabla_2 k](\Xtilde, \Xtilde) \right] \right)^{-\frac12}

Gives distance \smmd_{\SS,k,\lambda}(\PP, \QQ) = \sigma_{\SS,k,\lambda} \mmd_k(\PP, \QQ)
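As a worked special case: for a Gaussian kernel with bandwidth \ell on \R^d , k(x, x) = 1 and [\nabla_1 \!\cdot\! \nabla_2 k](x, x) = d / \ell^2 , so

\sigma_{\SS,k,\lambda} = \left( \lambda + 1 + \frac{d}{\ell^2} \right)^{-\frac12} ,

independent of \SS . With a deep kernel \ktop(\phi_\psic(x), \phi_\psic(y)) , the gradient term involves the Jacobian of \phi_\psic at \Xtilde , so the scaling genuinely depends on \SS .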

\begin{align} \optmmd \text{ has } & \mathcal F = \bigcup_{\psic \in \Psic} \left\{ f : \lVert f \rVert_{\h_{\psic}} \le 1 \right\} \\ \optsmmd \text{ has } & \mathcal F = \bigcup_{\psic \in \Psic} \left\{ f : \lVert f \rVert_{\h_{\psic}} \le \sigma_{\SS,k,\lambda} \right\} \end{align}

Deriving the Scaled MMD

\begin{gather} \E_{\Xtilde \sim \SS}[ f(\Xtilde)^2 ] + \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] + \lambda \lVert f \rVert_\h^2 \le 1 \\ \E_{\Xtilde \sim \SS}[ f(\Xtilde)^2 ] = \left\langle f, \E_{\Xtilde \sim \SS}\left[ k(\Xtilde, \cdot) \otimes k(\Xtilde, \cdot) \right] f \right\rangle \\ \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] = \left\langle f, \E_{\Xtilde \sim \SS}\left[ \sum_{i=1}^d \partial_i k(\Xtilde, \cdot) \otimes \partial_i k(\Xtilde, \cdot) \right] f \right\rangle \\ \langle f, D_\lambda f \rangle \le \lVert D_\lambda \rVert \, \lVert f \rVert_\h^2 \le \sigma_{\SS,k,\lambda}^{-2} \lVert f \rVert_\h^2 \end{gather}

Smoothness of \D_\mathrm{SMMD}


Theorem: \optsmmd is continuous.

If \SS has a density; \ktop is Gaussian/linear/…;
\phi_\psic is fully-connected, Leaky-ReLU, non-increasing width;
all weights in \Psic have bounded condition number; then

\W(\QQ_n, \PP) \to 0 \text{ implies } \optsmmd(\QQ_n, \PP) \to 0 .

Results on 160 \times 160 CelebA

SN-SMMD-GAN: KID 0.006
WGAN-GP: KID 0.022

Evaluating generative models

  • Human evaluation: good at precision, bad at recall
  • Likelihood: hard for GANs, maybe not right thing anyway
  • Two-sample tests: always reject!
  • Most common: Fréchet Inception Distance, FID
    • Run pretrained featurizer on model and target
    • Model each as Gaussian; compute W_2
    • Strong bias, small variance: very misleading
    • Simple examples where \operatorname{FID}(\cQ{\QQ_1}) > \operatorname{FID}(\cQ{\QQ_2}) but \widehat{\operatorname{FID}}(\cQ{\hat\QQ_1}) < \widehat{\operatorname{FID}}(\cQ{\hat\QQ_2}) for reasonable sample size
  • Our KID: \mmd^2 instead. Unbiased, asymptotically normal
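A sketch of how KID can be computed from featurized samples, using the unbiased \mmd^2 estimator; the degree-3 polynomial kernel on (e.g. Inception) features follows the KID paper, but treat the code details as illustrative:

```python
# Sketch: Kernel Inception Distance from precomputed feature vectors,
# with the unbiased MMD^2 estimator and k(x, y) = (x.y / d + 1)^3.
import numpy as np

def polynomial_kernel(A, B, degree=3):
    d = A.shape[1]
    return (A @ B.T / d + 1) ** degree

def kid(feats_real, feats_fake):
    K_XX = polynomial_kernel(feats_real, feats_real)
    K_YY = polynomial_kernel(feats_fake, feats_fake)
    K_XY = polynomial_kernel(feats_real, feats_fake)
    m, n = K_XX.shape[0], K_YY.shape[0]
    # unbiased: drop the diagonal (i = j) terms of the within-sample sums
    term_xx = (K_XX.sum() - np.trace(K_XX)) / (m * (m - 1))
    term_yy = (K_YY.sum() - np.trace(K_YY)) / (n * (n - 1))
    return term_xx + term_yy - 2 * K_XY.mean()
```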

Training process on CelebA


IV: Unnormalized density/score estimation

  • Problem: given samples \X_i \sim \PP_0 with density \pp_0
  • Model is kernel exponential family: for any f \in \cH , \begin{align} \pp_f(x) &= \exp\left( f(x) \right) \; q(x) / Z(f) \\&= \exp\big( \big\langle \underbrace{f}_{\substack{\text{natural}\\\text{parameter}}}, \underbrace{\phi(x)}_{\substack{\text{sufficient}\\\text{statistic}}} \big\rangle_\cH \big) \underbrace{q(x)}_{\substack{\text{base}\\\text{measure}}} / \underbrace{Z(f)}_{\text{normalizer}} \end{align} i.e. any density with \log \pp - \log q \in \cH
  • Gaussian k : dense in all continuous distributions on compact domains

Density estimation with KEFs

  • Fitting with maximum likelihood is tough:
    • Z(f) , \nabla Z(f) are tough to compute
    • Likelihood equations ill-posed for characteristic kernels
  • We choose to fit the unnormalized model
    • Could then estimate Z(f) once after fitting if necessary

Unnormalized density / score estimation

  • Don't necessarily need to compute Z(f) afterwards
  • f + \log q = \log p_f + \log Z(f) , the “energy”, lets us:
    • Find modes (global or local)
    • Sample (with MCMC)
  • The score, \nabla_x[ f(x) + \log q(x) ] = \nabla_x \log p_f(x) , lets us:
    • Run HMC for targets whose gradients we can't evaluate
    • Construct Monte Carlo control functionals

Score matching in KEFs [Sriperumbudur+ JMLR-17]

  • Idea: minimize the Fisher divergence J(p_0 \,\|\, p_f) : J(f) = \frac12 \int p_0(x) \norm{\nabla_x \log p_f(x) - \nabla_x \log p_0(x)}^2 \ud x
  • Under mild assumptions, J(f) = C(p_0) + \int p_0(x) \sum_{d=1}^D \left[ \partial_d^2 \log p_f(x) + \frac12 \left(\partial_d \log p_f(x)\right)^2 \right] \ud x
  • Can estimate with Monte Carlo
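For a generic differentiable model of \log p_f , the Monte Carlo objective can be written directly with automatic differentiation (a sketch; the kernel exponential family case instead admits the closed-form solution on the next slide):

```python
# Sketch: Monte Carlo score-matching objective via autograd,
# J_hat = mean_a sum_i [ d^2/dx_i^2 log p(X_a) + 0.5 * (d/dx_i log p(X_a))^2 ].
import torch

def score_matching_loss(log_p, X):
    X = X.clone().requires_grad_(True)
    logp = log_p(X).sum()
    (grad,) = torch.autograd.grad(logp, X, create_graph=True)   # scores, shape (n, d)
    loss = 0.5 * (grad ** 2).sum(dim=1)
    for i in range(X.shape[1]):                                 # diagonal of the Hessian
        (grad2_i,) = torch.autograd.grad(grad[:, i].sum(), X, create_graph=True)
        loss = loss + grad2_i[:, i]
    return loss.mean()
```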

Score matching in KEFs [Sriperumbudur+ JMLR-17]

  • Minimize regularized loss function: \hat J_\lambda(f) = \frac1n \sum_{a=1}^n \sum_{i=1}^d \left[ \partial_i^2 f(X_a) + \frac12 \left( \partial_i f(X_a) \right)^2 \right] + \frac12 \lambda \lVert f \rVert_\h^2
  • Representer theorem tells us the minimizer of \hat J_\lambda over \h is f_{\lambda,\mathcal X} \in \spn\left\{ \partial_i k_{X_a} \right\}_{a \in [n]}^{i \in [d]} \cup \spn\left\{ \partial_i^2 k_{X_a} \right\}_{a \in [n]}^{i \in [d]}

Score matching in a subspace

  • Best f \in \h is in {\color{#8172b2}{\h_\text{full}}} = \spn\left\{ \partial_i k_{X_a} \right\}_{a \in [\nc]}^{i \in [\dc]} \cup \spn\left\{ \partial_i^2 k_{X_a} \right\}_{a \in [\nc]}^{i \in [\dc]}
  • Find best f in dim \Mc subspace in \bigO(\nc\dc \Mc^2 + \Mc^3) time \beta = - (\tfrac1n \underbrace{B_{\Xc \Yc}\tp}_{\Mc \times \nc\dc} \; \underbrace{B^\phantom{\mathsf T}_{\Xc \Yc}}_{\nc\dc \times \Mc} + \lambda \underbrace{G_{\Yc \Yc}}_{\Mc \times \Mc} )^\dagger \underbrace{h_{\Yc}}_{\Mc \times 1}
    • {\color{#8172b2}{\h_\text{full}}} : \Mc = 2 \nc \dc , \mathcal{O}(\nc^3 \dc^3) time!

Nyström approximation [Sutherland+ AISTATS-18]

  • {\color{#8172b2}{\h_\text{full}}} = \spn\left\{ \partial_i k_{X_a} \right\}_{a \in [\nc]}^{i \in [\dc]} \cup \spn\left\{ \partial_i^2 k_{X_a} \right\}_{a \in [\nc]}^{i \in [\dc]}
  • Nyström approximation: find fit in different (smaller) \h_{\Yc}
  • One choice: pick \Yc \subset [\nc] , \lvert \Yc \rvert = \mc at random, then
    {\color{#ccb974}{\h_\text{nys}^Y}} = \spn\left\{ \partial_i k_{X_a} \right\}_{a \in \Yc}^{i \in [\dc]} \qquad \mathcal{O}(\nc \mc^2 \dc^3)\text{ time}
    Get the same rates with \mc = \sqrt \nc \log \nc (sometimes less)
  • “lite”: pick \Yc at random, then
    {\color{#64b5c2}{\h_\text{lite}^Y}} = \spn\left\{ k_{X_a} \right\}_{a \in \Yc} \qquad \mathcal{O}(\nc \mc^2 \dc)\text{ time}

Meta-learning a kernel

  • This was all with a fixed kernel and \lambda
  • Good results need carefully tuned kernel and \lambda
  • We can use a deep kernel as long as we split the data
    • Otherwise it would always pick bandwidth \to 0

Results

  • Learns local dataset geometry: better fits
  • On real data: slightly worse likelihoods, maybe better “shapes” than deep likelihood models

Recap

Combining a deep architecture with a kernel machine that takes the higher-level learned representation as input can be quite powerful.

— Y. Bengio & Y. LeCun (2007), “Scaling Learning Algorithms towards AI”

  • Two-sample testing [ICLR-17, ICML-20, NeurIPS-21]
    • \psic maximizing power criterion, for one task or many
  • Self-supervised learning with HSIC [NeurIPS-21]
    • Much better understanding of what's going on!
  • Generative modeling with MMD GANs [ICLR-18, NeurIPS-18]
    • Need a smooth loss function for the generator
  • Score matching in exponential families [AISTATS-18, ICML-19]
    • Avoid overfitting with closed-form fit on held-out data