\definecolor{cb1}{RGB}{76,114,176} \definecolor{cb2}{RGB}{221,132,82} \definecolor{cb3}{RGB}{85,168,104} \definecolor{cb4}{RGB}{196,78,82} \definecolor{cb5}{RGB}{129,114,179} \definecolor{cb6}{RGB}{147,120,96} \definecolor{cb7}{RGB}{218,139,195} \newcommand{\abs}[1]{\left\lvert #1 \right\rvert} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator{\bigO}{\mathcal{O}} \DeclareMathOperator{\D}{\mathcal{D}} \DeclareMathOperator*{\E}{\mathbb{E}} \DeclareMathOperator*{\Var}{Var} \DeclareMathOperator{\mean}{mean} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\spn}{span} \newcommand{\tp}{^\mathsf{T}} \DeclareMathOperator{\tr}{tr} \newcommand{\ud}{\mathrm{d}} \DeclareMathOperator{\W}{\mathcal{W}} \DeclareMathOperator{\KL}{KL} \DeclareMathOperator{\JS}{JS} \DeclareMathOperator{\mmd}{MMD} \DeclareMathOperator{\smmd}{SMMD} \DeclareMathOperator{\hsic}{HSIC} \DeclareMathOperator{\mmdhat}{\widehat{MMD}} \DeclareMathOperator{\hsichat}{\widehat{HSIC}} \newcommand{\optmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{MMD}^{#1}}} \DeclareMathOperator{\optmmdhat}{\hat{\mathcal{D}}_\mathrm{MMD}} \newcommand{\optsmmdp}{\operatorname{\mathcal{D}_\mathrm{SMMD}}} \newcommand{\optsmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{SMMD}^{\SS,#1,\lambda}}} \newcommand{\ktop}{k_\mathrm{top}} \newcommand{\lip}{\mathrm{Lip}} \newcommand{\cH}{\mathcal{H}} \newcommand{\h}{\mathcal{H}} \newcommand{\R}{\mathbb{R}} \newcommand{\cP}[1]{{\color{cb1} #1}} \newcommand{\PP}{\cP{\mathbb P}} \newcommand{\pp}{\cP{p}} \newcommand{\X}{\cP{X}} \newcommand{\Xp}{\cP{X'}} % kexpfam colors \newcommand{\nc}{{\color{#d62728}{n}}} \newcommand{\Xc}{{\color{#d62728}{X}}} \newcommand{\Mc}{{\color{#1f77b4}{M}}} \newcommand{\Yc}{{\color{#1f77b4}{Y}}} \newcommand{\mc}{{\color{#17becf}{m}}} \newcommand{\dc}{{\color{#2ca02c}{d}}} \newcommand{\Pdata}{\cP{\mathbb{P}_\mathrm{data}}} \newcommand{\cQ}[1]{{\color{cb2} #1}} \newcommand{\QQ}{\cQ{\mathbb Q}} 
\newcommand{\qq}{\cQ{q}} \newcommand{\qtheta}{\cQ{q_\theta}} \newcommand{\Y}{\cQ{Y}} \newcommand{\Yp}{\cQ{Y'}} \newcommand{\thetac}{\cQ{\theta}} \newcommand{\vtheta}{\thetac} \newcommand{\Qtheta}{\QQ_\thetac} \newcommand{\Gtheta}{\cQ{G_\theta}} \newcommand{\cZ}[1]{{\color{cb5} #1}} \newcommand{\Z}{\cZ{Z}} \newcommand{\Zc}{\cZ{\mathcal Z}} \newcommand{\ZZ}{\cZ{\mathbb Z}} \newcommand{\cpsi}[1]{{\color{cb3} #1}} \newcommand{\psic}{\cpsi{\psi}} \newcommand{\Psic}{\cpsi{\Psi}} \newcommand{\Dpsi}{\cpsi{D_\psi}} \newcommand{\SS}{\cpsi{\mathbb{S}}} \newcommand{\Xtilde}{\cpsi{\tilde{X}}} \newcommand{\Xtildep}{\cpsi{\tilde{X}'}}

Better deep learning (sometimes)
by learning kernel mean embeddings

Danica J. Sutherland (she/her)
University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
slides (but not the talk) about four related projects:
Feng Liu
Wenkai Xu
Hsiao-Yu (Fish) Tung
(+ 6 more…)
Yazhe Li
Roman Pogodin
Michael Arbel
Mikołaj Bińkowski
Li (Kevin) Wenliang
Heiko Strathmann
+ all with Arthur Gretton

LIKE22 - 12 Jan 2022


Deep learning and deep kernels

  • Deep learning: models usually of form f(x) = w\tp \phi_\psic(x)
    • With a learned \phi_\psic(x) : \mathcal X \to \R^{D}
  • If we fix \psic , have f \in \cH_\psic with k_\psic(x, y) = \phi_\psic(x)\tp \phi_\psic(y)
    • Same idea as NNGP approximation
  • Could train a classifier by:
    • Let \tilde L(\psic) = L(f_\psic^*) , loss of the best f_\psic^* \in \cH_\psic
    • Learn \psic by following \nabla_\psic \tilde L(\psic) = \nabla_\psic L(f^*_\psic)
  • Generalize to a deep kernel: k_\psic(x, y) = \kappa\left( \phi_\psic(x), \phi_\psic(y) \right)
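The deep kernel above can be sketched in a few lines. This is only an illustration under stated assumptions: the two-layer ReLU network standing in for \phi_\psic, the Gaussian choice of \kappa, and the names `make_feature_map`, `gaussian_kappa`, and `deep_kernel` are all hypothetical, not taken from the talk.

```python
import math
import random


def make_feature_map(in_dim, hidden_dim, out_dim, seed=0):
    """Build a tiny random two-layer ReLU network (a stand-in for phi_psi)."""
    rng = random.Random(seed)
    w1 = [[rng.gauss(0, 1 / math.sqrt(in_dim)) for _ in range(in_dim)]
          for _ in range(hidden_dim)]
    w2 = [[rng.gauss(0, 1 / math.sqrt(hidden_dim)) for _ in range(hidden_dim)]
          for _ in range(out_dim)]

    def phi(x):
        # Hidden layer with ReLU, then a linear output layer.
        hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in w1]
        return [sum(w * h for w, h in zip(row, hidden)) for row in w2]

    return phi


def gaussian_kappa(u, v, bandwidth=1.0):
    """Gaussian kernel on feature space (one possible choice of kappa)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2 * bandwidth ** 2))


def deep_kernel(phi, kappa):
    """k(x, y) = kappa(phi(x), phi(y))."""
    return lambda x, y: kappa(phi(x), phi(y))


phi = make_feature_map(in_dim=3, hidden_dim=8, out_dim=2)
k = deep_kernel(phi, gaussian_kappa)
x, y = [0.1, -0.4, 2.0], [1.0, 0.3, -0.5]
print(k(x, y), k(y, x), k(x, x))
```

Swapping `gaussian_kappa` for a plain inner product recovers the linear case k_\psic(x, y) = \phi_\psic(x)\tp \phi_\psic(y) from the earlier bullet.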

Normal deep learning \subset deep kernels

  • Take \phi_\psic(x) \in \R as output of last layer
  • k_\psic(x, y) = \phi_\psic(x) \phi_\psic(y) \fragment[1]{+ 1}
  • Final function in \cH_\psic will be a \phi_\psic(x) \fragment[1]{+ b}
  • With logistic loss: this is Platt scaling
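One way to see why the final function takes this form is the short derivation below. It is a sketch: the weight vector (a, b) is notation introduced here, and it treats the kernel with the + 1 term as a linear kernel on the augmented feature (\phi_\psic(x), 1).

```latex
\begin{align*}
k_\psic(x, y)
  &= \phi_\psic(x) \phi_\psic(y) + 1
   = \left\langle
       \begin{pmatrix} \phi_\psic(x) \\ 1 \end{pmatrix},
       \begin{pmatrix} \phi_\psic(y) \\ 1 \end{pmatrix}
     \right\rangle_{\R^2}
\\
f(x)
  &= \left\langle
       \begin{pmatrix} a \\ b \end{pmatrix},
       \begin{pmatrix} \phi_\psic(x) \\ 1 \end{pmatrix}
     \right\rangle_{\R^2}
   = a \, \phi_\psic(x) + b
\end{align*}
```

Fitting a and b with logistic loss on the scalar score \phi_\psic(x) is the Platt scaling mentioned above.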

So what?

  • This definitely does not say that deep learning is (even approximately) a kernel method
  • …despite what some people might want you to think
  • We know theoretically deep learning can learn some things faster than any kernel method [see Malach+ ICML-21 + refs]
  • But deep kernel learning ≠ traditional kernel models
    • exactly like how usual deep learning ≠ linear models

What do deep kernels give us?

  • In “normal” classification: slightly richer function space 🤷
  • Meta-learning: common \phi_\psic , constantly varying f_\psic^*
  • Two-sample testing
    • Simple form of f_\psic^* for cheap permutation testing
  • Self-supervised learning
    • Better understanding of what's really going on, at least
  • Generative modeling with MMD GANs
    • Better gradient for generator to follow (?)
  • Score matching in exponential families (density estimation)
    • Optimize regularization weights, better gradient (?)

Maximum Mean Discrepancy (MMD)

\begin{align}
\mmd&_k(\PP, \QQ)
  = \sup_{\lVert f \rVert_\cH \le 1} \E_{\X \sim \PP}[f(\X)] - \E_{\Y \sim \QQ}[f(\Y)]
\\&\fragment[1]{
  = \sup_{\lVert f \rVert_\cH \le 1}
    \E_{\X \sim \PP}[\langle f, \varphi(\X) \rangle_\cH]
  - \E_{\Y \sim \QQ}[\langle f, \varphi(\Y) \rangle_\cH]
}
\\&\fragment[2]{
  = \sup_{\lVert f \rVert_\cH \le 1}
    \left\langle f, \E_{\X \sim \PP}[\varphi(\X)] - \E_{\Y \sim \QQ}[\varphi(\Y)] \right\rangle_\cH
}
\\&\fragment[3]{
  = \sup_{\lVert f \rVert_\cH \le 1}
    \left\langle f, \mu^k_\PP - \mu^k_\QQ \right\rangle_\cH
}
\fragment[4]{
  = \left\lVert \mu^k_\PP - \mu^k_\QQ \right\rVert_\cH
}
\end{align}

MMD as feature matching

\mmd_k(\PP, \QQ) = \left\lVert \E_{\X \sim \PP}[ \varphi(\X) ] - \E_{\Y \sim \QQ}[ \varphi(\Y) ] \right\rVert_{\cH}

  • \varphi : \mathcal X \to \cH is the feature map for k(x, y) = \langle \varphi(x), \varphi(y) \rangle
  • If k(x, y) = x\tp y , \varphi(x) = x , then
    the MMD is distance between means
  • Many kernels: infinite-dimensional \cH

Estimating MMD

\begin{gather}
\mmd_k^2(\PP, \QQ)
= \E_{\X, \Xp \sim \PP}[k(\X, \Xp)]
+ \E_{\Y, \Yp \sim \QQ}[k(\Y, \Yp)]
- 2 \E_{\substack{\X \sim \PP\\\Y \sim \QQ}}[k(\X, \Y)]
\\
\fragment[0]{
  \mmdhat_k^2(\X, \Y)
  = \fragment[1][highlight-current-red]{\mean(K_{\X\X})}
  + \fragment[2][highlight-current-red]{\mean(K_{\Y\Y})}
  - 2 \fragment[3][highlight-current-red]{\mean(K_{\X\Y})}
}
\end{gather}
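The plug-in estimator above is short enough to write out directly. The following is a minimal pure-Python sketch, assuming samples are lists of equal-length numeric vectors. The function names and the default Gaussian bandwidth are illustrative choices, not part of the talk.

```python
import math


def gaussian_kernel(x, y, bandwidth=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2 * bandwidth ** 2))


def linear_kernel(x, y):
    """k(x, y) = x^T y, for which the MMD is the distance between means."""
    return sum(a * b for a, b in zip(x, y))


def mean_kernel(xs, ys, kernel):
    """Mean of the kernel matrix between two samples."""
    return sum(kernel(x, y) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd2_biased(xs, ys, kernel=gaussian_kernel):
    """mean(K_XX) + mean(K_YY) - 2 * mean(K_XY), diagonal terms included."""
    return (mean_kernel(xs, xs, kernel)
            + mean_kernel(ys, ys, kernel)
            - 2 * mean_kernel(xs, ys, kernel))


xs = [[0.0], [2.0]]
ys = [[3.0], [5.0]]
print(mmd2_biased(xs, ys, linear_kernel))  # squared distance between the means 1 and 4
print(mmd2_biased(xs, ys))                 # Gaussian-kernel version
```

With the linear kernel the result is the squared distance between the sample means, matching the feature-matching view of the MMD. This version keeps the diagonal terms of K_{\X\X} and K_{\Y\Y}, so it is the biased estimator.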




I: Two-sample testing

  • Given samples from two unknown distributions \X \sim \PP \qquad \Y \sim \QQ
  • Question: is \PP = \QQ ?
  • Hypothesis testing approach: H_0: \PP = \QQ \qquad H_1: \PP \ne \QQ
  • Reject H_0 if test statistic \hat T(\X, \Y) > c_\alpha
  • Do smokers/non-smokers get different cancers?
  • Do Brits have the same friend network types as Americans?
  • When does my laser agree with the one on Mars?
  • Are storms in the 2000s different from storms in the 1800s?
  • Does presence of this protein affect DNA binding? [MMDiff2]
  • Do these dob and birthday columns mean the same thing?
  • Does my generative model \Qtheta match \Pdata ?
  • Independence testing: is P(\X, \Y) = P(\X) P(\Y) ?

What's a hypothesis test again?


Permutation testing to find c_\alpha

Need \Pr_{H_0}\left( \hat T(\X, \Y) > c_\alpha \right) \le \alpha


c_\alpha : (1-\alpha) th quantile of \left\{ \fragment[1]{\hat T(\cP{\tilde X_1}, \cQ{\tilde Y_1}), \;} \fragment[2]{\hat T(\cP{\tilde X_2}, \cQ{\tilde Y_2}), \;} \fragment[3]{\cdots} \right\}
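The threshold above can be estimated by repeatedly re-splitting the pooled sample. The sketch below is one simple way to do that, assuming scalar data and a biased Gaussian-kernel MMD statistic. The function names, the number of permutations, and the quantile indexing rule are illustrative assumptions.

```python
import math
import random


def mmd2_biased(xs, ys, bandwidth=1.0):
    """Biased MMD^2 estimate for scalar samples with a Gaussian kernel."""
    def k(a, b):
        return math.exp(-(a - b) ** 2 / (2 * bandwidth ** 2))

    def mean_k(us, vs):
        return sum(k(u, v) for u in us for v in vs) / (len(us) * len(vs))

    return mean_k(xs, xs) + mean_k(ys, ys) - 2 * mean_k(xs, ys)


def permutation_test(xs, ys, statistic, alpha=0.05, n_perms=200, seed=0):
    """Return (observed statistic, estimated c_alpha, reject H0?)."""
    rng = random.Random(seed)
    observed = statistic(xs, ys)
    pooled = list(xs) + list(ys)
    null_stats = []
    for _ in range(n_perms):
        # Under H0 the labels are exchangeable, so any re-split is as likely.
        rng.shuffle(pooled)
        null_stats.append(statistic(pooled[:len(xs)], pooled[len(xs):]))
    null_stats.sort()
    # Empirical (1 - alpha) quantile of the permuted statistics.
    index = min(n_perms - 1, math.ceil((1 - alpha) * n_perms) - 1)
    threshold = null_stats[index]
    return observed, threshold, observed > threshold


data_rng = random.Random(1)
xs = [data_rng.gauss(0.0, 1.0) for _ in range(20)]
ys = [data_rng.gauss(3.0, 1.0) for _ in range(20)]
observed, threshold, reject = permutation_test(xs, ys, mmd2_biased)
print(observed, threshold, reject)
```

Recomputing the statistic from scratch for every permutation is the simplest option. For kernel statistics, the kernel matrix on the pooled sample could instead be computed once and re-indexed per permutation.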

MMD-based tests

  • If k is characteristic, \mmd(\PP, \QQ) = 0 iff \PP = \QQ
  • Efficient permutation testing for \mmdhat(\X, \Y)
    • H_0 : n \mmdhat^2 converges in distribution
    • H_1 : \sqrt n ( \mmdhat^2 - \mmd^2 ) asymptotically normal
  • Any characteristic kernel gives consistent test…eventually
  • Need enormous n if kernel is bad for problem

Classifier two-sample tests

  • \hat T(\X, \Y) is the accuracy of f on the test set
  • Under H_0 , classification impossible: n \hat T \sim \mathrm{Binomial}(n, \frac12)
  • With k(x, y) = \frac14 f(x) f(y) where f(x) \in \{0, 1\} ,
    get \mmdhat(\X, \Y) = \left\lvert \hat{T}(\X, \Y) - \frac12 \right\rvert

Optimizing test power

  • Asymptotics of \mmdhat^2 give us immediately that \Pr_{H_1}\left( n \mmdhat^2 > c_\alpha \right) \approx \Phi\left( \frac{\sqrt n \mmd^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt n \sigma_{H_1}} \right)
    • \mmd , \sigma_{H_1} , c_\alpha are constants, so the first term dominates as n grows
  • Pick k to maximize an estimate of \mmd^2 / \sigma_{H_1}
  • Can show uniform \mathcal O_P(n^{-\frac13}) convergence of estimator
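The power approximation above follows in one line from the asymptotic normality under H_1. The sketch below spells that step out, writing Z_n for the standardized statistic (notation introduced here) and treating it as exactly standard normal:

```latex
\begin{align*}
Z_n &:= \frac{\sqrt n \left( \mmdhat^2 - \mmd^2 \right)}{\sigma_{H_1}}
      \;\approx\; \mathcal N(0, 1)
\\
\Pr_{H_1}\left( n \mmdhat^2 > c_\alpha \right)
  &= \Pr_{H_1}\left( Z_n > \frac{c_\alpha}{\sqrt n \, \sigma_{H_1}} - \frac{\sqrt n \, \mmd^2}{\sigma_{H_1}} \right)
\\
  &\approx 1 - \Phi\left( \frac{c_\alpha}{\sqrt n \, \sigma_{H_1}} - \frac{\sqrt n \, \mmd^2}{\sigma_{H_1}} \right)
   = \Phi\left( \frac{\sqrt n \, \mmd^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt n \, \sigma_{H_1}} \right)
\end{align*}
```

The first term grows like \sqrt n and the second shrinks like 1 / \sqrt n, which is why the ratio \mmd^2 / \sigma_{H_1} is the quantity to maximize.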

Blobs dataset

Blobs kernels

Blobs results

CIFAR-10 vs CIFAR-10.1

Train on 1 000, test on 1 031, repeat 10 times. Rejection rates:

Ablation vs classifier-based tests

                       Cross-entropy | Max power
High-d Gauss. mix.     0.47   0.59   0.29   0.64   0.66
MNIST vs GAN           0.65   0.71   0.80   0.94   1.00


  • What if you don't have much data for your testing problem?
  • Need enough data to pick a good kernel
  • Also need enough test data to actually detect the difference
  • Best split depends on best kernel's quality / how hard to find
    • Don't know that ahead of time; can't try more than one


  • One idea: what if we have related problems?
  • Similar setup to meta-learning: (from Wei+ 2018)

Meta-testing for CIFAR-10 vs CIFAR-10.1

  • CIFAR-10 has 60,000 images, but CIFAR-10.1 only has 2,031
  • Where do we get related data from?
  • One option: set up tasks to distinguish classes of CIFAR-10 (airplane vs automobile, airplane vs bird, ...)

One approach (MAML-like)

A_\theta is, e.g., 5 steps of gradient descent

we learn the initialization, maybe step size, etc

This works, but not as well as we'd hoped…
Initialization might work okay on everything, not really adapt

Another approach: Meta-MKL

Inspired by classic multiple kernel learning

Only need to learn linear combination \beta_i on test task:
much easier

Theoretical analysis for Meta-MKL

  • Same big-O dependence on test task size 😐
  • But multiplier is much better:
    based on number of meta-training tasks, not on network size
  • (Analysis assumes meta-tasks are “related” enough)

Results on CIFAR-10.1

Challenges for testing

  • When \PP \ne \QQ , can we tell how they're different?
    • Methods so far: some mostly for low-d
    • Some look at points with large critic function
  • Finding kernels / features that can't do certain things
    • distinguish by emotion, but can't distinguish by skin color
  • Avoid the need for data splitting (selective inference)
    • Kübler+ NeurIPS-20 gave one method
    • only for multiple kernel learning
    • only with data-inefficient (streaming) estimator

II: Self-supervised learning

Given a bunch of unlabeled samples \X ,
want to find “good” features \Z = f(\X)
(e.g. so that a linear classifier on \Z works with few samples)

One common approach: contrastive learning

InfoNCE [van den Oord+ 2018, Poole+ ICML-19]

\begin{align*}
\E_{\Z_1}\left[ \E_{\Z_2 \sim \text{pos}}[ k(\Z_1, \Z_2) ] - \log \E_{\Z_2}[ \exp(k(\Z_1, \Z_2)) ] \right]
\\\le \operatorname{MI}(\Z_1, \Z_2)
\end{align*}
CPC, SimCLR, MoCo, SwAV, …

Mutual information isn't why SSL works!

InfoNCE approximates MI between "positive" views

But MI is invariant to transformations important to SSL!

Hilbert-Schmidt Independence Criterion (HSIC)

\begin{align*} \hsic(\X, \Y) &= \norm{ \E[ \phi(\X) \otimes \phi(\Y) ] - \E[ \phi(\X) ] \otimes \E[ \phi(\Y) ] }_{HS}^2 \\&= \mmd^2(\PP_{\X\Y}, \PP_\X \otimes \PP_\Y) \\&\fragment{\le C_k \operatorname{MI}(\X, \Y) \qquad\text{($C_k$ depends only on $\norm{k}_\infty$)}} \end{align*}

With a linear kernel: \hsic = \norm{ \E[\X \Y\tp] - \E[\X] \E[\Y]\tp }_F^2

Estimator based on kernel matrices: \hsichat = \frac{1}{(n-1)^2} \tr(\cP{K} H \cQ{L} H)
  • H is the centering matrix
  • \cP{K} is the kernel matrix on \X
  • \cQ{L} is the kernel matrix on \Y
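The trace form of the estimator can be computed without building H explicitly, because H K H is just K with its row and column means removed. The sketch below is a pure-Python illustration under that observation. The function names and the scalar example data are assumptions made for the example.

```python
def gram(points, kernel):
    """Kernel matrix on a sample."""
    return [[kernel(a, b) for b in points] for a in points]


def center(matrix):
    """Compute H M H, where H = I - (1/n) 1 1^T is the centering matrix."""
    n = len(matrix)
    row_means = [sum(row) / n for row in matrix]
    col_means = [sum(matrix[i][j] for i in range(n)) / n for j in range(n)]
    grand_mean = sum(row_means) / n
    return [[matrix[i][j] - row_means[i] - col_means[j] + grand_mean
             for j in range(n)] for i in range(n)]


def hsic_biased(xs, ys, kernel_x, kernel_y):
    """tr(K H L H) / (n - 1)^2, using tr(K H L H) = tr((H K H) L)."""
    n = len(xs)
    K_centered = center(gram(xs, kernel_x))
    L = gram(ys, kernel_y)
    trace = sum(K_centered[i][j] * L[j][i] for i in range(n) for j in range(n))
    return trace / (n - 1) ** 2


def linear(a, b):
    """Linear kernel on scalars."""
    return a * b


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
print(hsic_biased(xs, ys, linear, linear))          # squared sample covariance of xs and ys
print(hsic_biased(xs, [5.0] * 4, linear, linear))   # constant ys carry no dependence
```

With linear kernels on scalars this reduces to the squared sample covariance, the one-dimensional case of the Frobenius-norm formula above.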


\mathcal L_{\text{SSL-HSIC}} = -\hsic(\Z, \Y) + \gamma \sqrt{\hsic(\Z, \Z)}
( \Y is just an indicator of which source image it came from)

Target representation \Z' is output of f_\theta(\X)

HSIC uses learned kernel k(\Z'_1, \Z'_2) = \operatorname{IMQ}(g(\Z'_1), g(\Z'_2))

Your InfoNCE model is secretly a kernel method…

\begin{gather*}
\mathcal{L}_{\mathrm{InfoNCE}}(\theta)
= - \E_{(\Z_1,\Z_2) \sim \mathrm{pos}}\left[ k(\Z_1, \Z_{2}) \right]
+ \E_{\Z_1} \log \E_{\Z_2} \left[ \exp k(\Z_{1}, \Z_{2}) \right]
\qquad\\\qquad
\fragment{
  \approx
  \underbrace{
    - \E_{(\Z_1,\Z_2) \sim \mathrm{pos}}\left[k(\Z_1, \Z_2)\right]
    + \E_{\Z_1} \E_{\Z_2} \left[k(\Z_1, \Z_2) \right]
  }_{\propto -\hsic(\Z, \Y)}
  + \frac12 \underbrace{
    \E_{\Z_1} \left[\Var_{\Z_2} \left[ k(\Z_1, \Z_2) \right] \right]
  }_{\textrm{variance penalty}}
}
\end{gather*}

Very similar loss! Just different regularizer

When variance is small,
-\hsic(\Z, \Y) + \gamma \hsic(\Z, \Z) \le \mathcal L_{\mathrm{InfoNCE}} + o(\mathrm{variance})
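The approximation of the InfoNCE loss rests on a second-order expansion of the log-mean-exp term. As a sketch of that step, for fixed \Z_1 the cumulant expansion of \log \E \exp gives the following, where terms beyond the variance are only indicated, not bounded:

```latex
\begin{align*}
\log \E_{\Z_2}\left[ \exp k(\Z_1, \Z_2) \right]
  &= \E_{\Z_2}\left[ k(\Z_1, \Z_2) \right]
   + \frac12 \Var_{\Z_2}\left[ k(\Z_1, \Z_2) \right]
   + (\text{higher-order cumulants})
\end{align*}
```

Taking \E_{\Z_1} of both sides and adding the positive-pair term - \E_{(\Z_1,\Z_2) \sim \mathrm{pos}}[k(\Z_1, \Z_2)] gives the HSIC-like term plus the variance penalty. The approximation is accurate when the variance of k is small.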

Clustering interpretation

SSL-HSIC estimates agreement of \Z with cluster structure of \Y

With linear kernels: -\hsic(\Z, \Y) \propto \sum_{i=1}^n \sum_{p=1}^m \norm{\cZ{Z_i^{(p)}} - \cZ{\bar Z_i}}^2 - n m where \cZ{\bar Z_i} is mean of the m augmentations

Resembles BYOL loss with no target network (but still works!)

ImageNet results: linear evaluation

Transfer from ImageNet to classification tasks

III: Training implicit generative models

Given samples from a distribution \PP over \mathcal X ,
we want a model that can produce new samples from \Qtheta \approx \PP

\X \sim \PP

\Y \sim \Qtheta
“Everybody Dance Now” [Chan+ ICCV-19]

Generator networks

Fixed distribution of latents: \Z \sim \mathrm{Uniform}\left( [-1, 1]^{100} \right)

Maps through a network: G_\thetac(\Z) \sim \Qtheta

DCGAN generator [Radford+ ICLR-16]

How to choose \thetac ?

GANs and their flaws

  • GANs [Goodfellow+ NeurIPS-14] minimize discriminator accuracy (like classifier test) between \PP and \Qtheta
  • Problem: if there's a perfect classifier, discontinuous loss, no gradient to improve it [Arjovsky/Bottou ICLR-17]
  • Disjoint at init:
    \PP :
    \cQ{\QQ_{\theta_0}} :
  • For usual \Gtheta : \R^{100} \to \R^{64 \times 64 \times 3} , \Qtheta is supported on a countable union of manifolds with dim \le 100
  • “Natural image manifold” usually considered low-dim
  • Won't align at init, so won't ever align


  • Integral probability metrics with “smooth” \mathcal F are continuous
  • WGAN: \mathcal F a set of neural networks satisfying \norm{f}_L \le 1
  • WGAN-GP: instead penalize \E \norm{\nabla_x f(x)} near the data
  • Both losses are MMD with k_\psic(x, y) = \phi_\psic(x) \phi_\psic(y)
    • \min_\theta \left[ \optmmd(\PP, \Qtheta) = \sup_{\psic \in \Psic} \mmd_{\psic}(\PP, \Qtheta) \right]
  • Some kind of constraint on \phi_\psic is important!

Non-smoothness of plain MMD GANs

Illustrative problem in \R , DiracGAN [Mescheder+ ICML-18]:

  • Just need to stay away from tiny bandwidths \psi
  • …deep kernel analogue is hard.
  • Instead, keep witness function from being too steep
  • \sup_x \lVert \nabla f(x) \rVert would give Wasserstein
    • Nice distance, but hard to estimate
  • Control \lVert \nabla f(\Xtilde) \rVert on average, near the data

MMD-GAN with gradient control

  • If \Psic gives uniformly Lipschitz critics, \optmmd is smooth
  • Original MMD-GAN paper [Li+ NeurIPS-17]: box constraint
  • We [Bińkowski+ ICLR-18] used gradient penalty on critic instead
    • Better in practice, but doesn't fix the Dirac problem…

New distance: Scaled MMD

Want to ensure \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] \le 1

Can solve with \langle \partial_i \phi(x), f \rangle_{\cH} = \partial_i f(x) …but too expensive!

Guaranteed if \lVert f \rVert_\cH \le \sigma_{\SS,k,\lambda}

\sigma_{\SS,k,\lambda} := \left( \lambda + \E_{\Xtilde \sim \SS}\left[ k(\Xtilde, \Xtilde) + [\nabla_1 \!\cdot\! \nabla_2 k](\Xtilde, \Xtilde) \right] \right)^{-\frac12}

Gives distance \smmd_{\SS,k,\lambda}(\PP, \QQ) = \sigma_{\SS,k,\lambda} \mmd_k(\PP, \QQ)
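As a worked example of the scaling factor, consider a Gaussian kernel on \R^d with bandwidth \psi. The kernel choice and the symbol d for the input dimension are assumptions made for this illustration. Both terms inside the expectation are then constants, so the scale does not depend on \SS:

```latex
\begin{align*}
k(y, z) &= \exp\left( -\frac{\lVert y - z \rVert^2}{2 \psi^2} \right),
\qquad k(\Xtilde, \Xtilde) = 1
\\
\frac{\partial^2 k(y, z)}{\partial y_i \, \partial z_i}
  &= \left( \frac{1}{\psi^2} - \frac{(y_i - z_i)^2}{\psi^4} \right) k(y, z)
\quad\Longrightarrow\quad
[\nabla_1 \!\cdot\! \nabla_2 k](\Xtilde, \Xtilde) = \frac{d}{\psi^2}
\\
\sigma_{\SS,k,\lambda} &= \left( \lambda + 1 + \frac{d}{\psi^2} \right)^{-\frac12}
\end{align*}
```

As \psi \to 0 the scale behaves like \psi / \sqrt d, so very small bandwidths are heavily down-weighted in the Scaled MMD.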

\begin{align} \optmmd \text{ has } & \mathcal F = \bigcup_{\psic \in \Psic} \left\{ f : \lVert f \rVert_{\h_{\psic}} \le 1 \right\} \\ \optsmmd \text{ has } & \mathcal F = \bigcup_{\psic \in \Psic} \left\{ f : \lVert f \rVert_{\h_{\psic}} \le \sigma_{\SS,k,\lambda} \right\} \end{align}

Deriving the Scaled MMD

\begin{gather} \fragment[0]{\E_{\Xtilde \sim \SS}[ f(\Xtilde)^2 ] + } \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] \fragment[0]{+ \lambda \lVert f \rVert_\h^2} \le 1 \\ \fragment[1]{ \E_{\Xtilde \sim \SS}[ f(\Xtilde)^2 ] = \left\langle f, \E_{\Xtilde \sim \SS}\left[ k(\Xtilde, \cdot) \otimes k(\Xtilde, \cdot) \right] f \right\rangle } \\ \fragment[2]{ \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] = \left\langle f, \E_{\Xtilde \sim \SS}\left[ \sum_{i=1}^d \partial_i k(\Xtilde, \cdot) \otimes \partial_i k(\Xtilde, \cdot) \right] f \right\rangle } \\ \fragment[3]{ \langle f, D_\lambda f \rangle \le \lVert D_\lambda \rVert \, \lVert f \rVert_\h^2 \le \sigma_{\SS,k,\lambda}^{-2} \lVert f \rVert_\h^2 } \end{gather}

Smoothness of \D_\mathrm{SMMD}


Theorem: \optsmmd is continuous.

If \SS has a density; \ktop is Gaussian/linear/…;
\phi_\psic is fully-connected, Leaky-ReLU, non-increasing width;
all weights in \Psic have bounded condition number; then

\W(\QQ_n, \PP) \to 0 \text{ implies } \optsmmd(\QQ_n, \PP) \to 0 .