\definecolor{cb1}{RGB}{76,114,176} \definecolor{cb2}{RGB}{221,132,82} \definecolor{cb3}{RGB}{85,168,104} \definecolor{cb4}{RGB}{196,78,82} \definecolor{cb5}{RGB}{129,114,179} \definecolor{cb6}{RGB}{147,120,96} \definecolor{cb7}{RGB}{218,139,195}
\newcommand{\abs}[1]{\left\lvert #1 \right\rvert} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert}
\DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator{\bigO}{\mathcal{O}} \DeclareMathOperator{\D}{\mathcal{D}} \DeclareMathOperator*{\E}{\mathbb{E}} \DeclareMathOperator*{\Var}{Var} \DeclareMathOperator{\mean}{mean} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\spn}{span}
\newcommand{\tp}{^\mathsf{T}} \newcommand{\ud}{\mathrm{d}}
\DeclareMathOperator{\W}{\mathcal{W}} \DeclareMathOperator{\KL}{KL} \DeclareMathOperator{\JS}{JS}
\DeclareMathOperator{\mmd}{MMD} \DeclareMathOperator{\smmd}{SMMD} \DeclareMathOperator{\mmdhat}{\widehat{MMD}}
\newcommand{\optmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{MMD}^{#1}}} \DeclareMathOperator{\optmmdhat}{\hat{\mathcal{D}}_\mathrm{MMD}}
\newcommand{\optsmmdp}{\operatorname{\mathcal{D}_\mathrm{SMMD}}} \newcommand{\optsmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{SMMD}^{\SS,#1,\lambda}}}
\newcommand{\ktop}{k_\mathrm{top}} \newcommand{\lip}{\mathrm{Lip}}
\newcommand{\cH}{\mathcal{H}} \newcommand{\h}{\mathcal{H}} \newcommand{\R}{\mathbb{R}}
\newcommand{\cP}[1]{{\color{cb1} #1}} \newcommand{\PP}{\cP{\mathbb P}} \newcommand{\pp}{\cP{p}} \newcommand{\X}{\cP{X}} \newcommand{\Xp}{\cP{X'}}
% kexpfam colors
\newcommand{\nc}{{\color{#d62728}{n}}} \newcommand{\Xc}{{\color{#d62728}{X}}} \newcommand{\Mc}{{\color{#1f77b4}{M}}} \newcommand{\Yc}{{\color{#1f77b4}{Y}}} \newcommand{\mc}{{\color{#17becf}{m}}} \newcommand{\dc}{{\color{#2ca02c}{d}}}
\newcommand{\Pdata}{\cP{\mathbb{P}_\mathrm{data}}}
\newcommand{\cQ}[1]{{\color{cb2} #1}} \newcommand{\QQ}{\cQ{\mathbb Q}} \newcommand{\qq}{\cQ{q}} \newcommand{\qtheta}{\cQ{q_\theta}} \newcommand{\Y}{\cQ{Y}} \newcommand{\Yp}{\cQ{Y'}}
\newcommand{\thetac}{\cQ{\theta}} \newcommand{\vtheta}{\thetac} \newcommand{\Qtheta}{\QQ_\thetac} \newcommand{\Gtheta}{\cQ{G_\theta}}
\newcommand{\cZ}[1]{{\color{cb5} #1}} \newcommand{\Z}{\cZ{Z}} \newcommand{\Zc}{\cZ{\mathcal Z}} \newcommand{\ZZ}{\cZ{\mathbb Z}}
\newcommand{\cpsi}[1]{{\color{cb3} #1}} \newcommand{\psic}{\cpsi{\psi}} \newcommand{\Psic}{\cpsi{\Psi}} \newcommand{\Dpsi}{\cpsi{D_\psi}} \newcommand{\SS}{\cpsi{\mathbb{S}}} \newcommand{\Xtilde}{\cpsi{\tilde{X}}} \newcommand{\Xtildep}{\cpsi{\tilde{X}'}}

Deep kernel-based distances between distributions

Danica J. Sutherland

Based on work with:
Michael Arbel
Mikołaj Bińkowski
Soumyajit De
Arthur Gretton
Feng Liu
Jie Lu
Aaditya Ramdas
Alex Smola
Heiko Strathmann
Hsiao-Yu (Fish) Tung
Wenkai Xu
Guangquan Zhang
\D\left( \text{[image samples]} , \; \text{[image samples]} \right)

PIHOT kick-off, 30 Jan 2021


What's a kernel again?

  • Linear classifiers: \hat y(x) = \sign(f(x)) , f(x) = w\tp \left( x, 1 \right)
  • Use a “richer” x : f(x) = w\tp \left( x, x^2, 1\right) = w\tp \phi(x)
  • Can avoid explicit \phi(x) ; instead k(x, y) = \langle \phi(x), \phi(y) \rangle_{\cH}
  • “Kernelized” algorithms access data only through k(x, y) , e.g. f(x) = \langle w, \phi(x) \rangle_\cH = \sum_{i=1}^n \alpha_i k(X_i, x)
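To make this concrete, here is a minimal numpy sketch (mine, not from the talk) of evaluating such a kernelized predictor with a Gaussian kernel; the data and coefficients \alpha_i are placeholders.

```python
# Minimal sketch: a kernelized predictor f(x) = sum_i alpha_i k(X_i, x)
# with the Gaussian RBF kernel. Data and coefficients are placeholders.
import numpy as np

def gaussian_kernel(A, B, sigma=1.0):
    # k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) for all pairs of rows
    return np.exp(-((A[:, None] - B[None]) ** 2).sum(-1) / (2 * sigma ** 2))

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))       # "training" points X_1, ..., X_n
alpha = rng.normal(size=20)        # coefficients alpha_i (placeholder values)

def f(x_new):
    # evaluate f at the rows of x_new via the kernel trick: no explicit phi
    return gaussian_kernel(x_new, X) @ alpha

y_hat = np.sign(f(rng.normal(size=(5, 2))))   # sign(f(x)) as on the slide
```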

Reproducing Kernel Hilbert Space (RKHS)

  • Ex. (Gaussian RBF / exponentiated quadratic / squared exponential / …): k(x, y) = \exp\left( - \frac{\norm{x - y}^2}{2 \sigma^2} \right)
  • Reproducing property: \langle f, \phi(x) \rangle_\cH = f(x) for f \in \cH
  • \cH = \operatorname{cl}\left( \left\{ \sum_{i=1}^n \alpha_i \phi(X_i) \mid n \ge 0, \alpha \in \R^n, X_i \in \mathcal X \right\} \right)
  • \norm{\sum_i \alpha_i \phi(X_i)}_\cH^2 = \alpha\tp K \alpha , where K_{ij} = k(X_i, X_j)
  • \argmin_{f \in \cH} L(f(X_1), \dots, f(X_n)) + \lambda \norm{f}_\cH^2 is in \left\{ \sum_{i=1}^n \alpha_i \phi(X_i) \mid \alpha \in \R^n \right\} – the representer theorem
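A hedged illustration of the representer theorem in action (my sketch, not the talk's): kernel ridge regression, where the minimizer has the form \sum_i \alpha_i \phi(X_i) with \alpha = (K + \lambda I)^{-1} y . The dataset and hyperparameters below are arbitrary.

```python
# Sketch: kernel ridge regression as an instance of the representer theorem.
# The minimizer of  sum_i (f(X_i) - y_i)^2 + lam * ||f||_H^2  has the form
# f = sum_i alpha_i phi(X_i), with coefficients alpha = (K + lam I)^{-1} y.
import numpy as np

def gaussian_kernel(A, B, sigma=1.0):
    return np.exp(-((A[:, None] - B[None]) ** 2).sum(-1) / (2 * sigma ** 2))

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(50, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=50)

lam = 1e-2
K = gaussian_kernel(X, X)                             # K_ij = k(X_i, X_j)
alpha = np.linalg.solve(K + lam * np.eye(len(X)), y)  # representer coefficients

f_train = K @ alpha                                   # fitted values f(X_i)
rkhs_norm_sq = alpha @ K @ alpha                      # ||f||_H^2 = alpha^T K alpha
```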

Maximum Mean Discrepancy (MMD)

\begin{align}
\mmd_k(\PP, \QQ)
&= \sup_{\lVert f \rVert_\cH \le 1} \E_{\X \sim \PP}[f(\X)] - \E_{\Y \sim \QQ}[f(\Y)] \\
&= \sup_{\lVert f \rVert_\cH \le 1} \E_{\X \sim \PP}[\langle f, \varphi(\X) \rangle_\cH] - \E_{\Y \sim \QQ}[\langle f, \varphi(\Y) \rangle_\cH] \\
&= \sup_{\lVert f \rVert_\cH \le 1} \left\langle f, \E_{\X \sim \PP}[\varphi(\X)] - \E_{\Y \sim \QQ}[\varphi(\Y)] \right\rangle_\cH \\
&= \sup_{\lVert f \rVert_\cH \le 1} \left\langle f, \mu^k_\PP - \mu^k_\QQ \right\rangle_\cH
= \left\lVert \mu^k_\PP - \mu^k_\QQ \right\rVert_\cH
\end{align}

MMD as feature matching

\mmd_k(\PP, \QQ) = \left\lVert \E_{\X \sim \PP}[ \varphi(\X) ] - \E_{\Y \sim \QQ}[ \varphi(\Y) ] \right\rVert_{\cH}

  • \varphi : \mathcal X \to \cH is the feature map for k(x, y) = \langle \varphi(x), \varphi(y) \rangle_\cH
  • If k(x, y) = x\tp y , then \varphi(x) = x : MMD is the distance between the means
  • For many kernels, \cH is infinite-dimensional

MMD and OT

Estimating MMD

\begin{gather}
\mmd_k^2(\PP, \QQ)
= \E_{\X, \Xp \sim \PP}[k(\X, \Xp)]
+ \E_{\Y, \Yp \sim \QQ}[k(\Y, \Yp)]
- 2 \E_{\substack{\X \sim \PP\\\Y \sim \QQ}}[k(\X, \Y)]
\\
\mmdhat_k^2(\X, \Y)
= \mean(K_{\X\X}) + \mean(K_{\Y\Y}) - 2 \mean(K_{\X\Y})
\end{gather}

K_{\X\X} = \begin{pmatrix} 1.0 & 0.2 & 0.6 \\ 0.2 & 1.0 & 0.5 \\ 0.6 & 0.5 & 1.0 \end{pmatrix}
\qquad
K_{\Y\Y} = \begin{pmatrix} 1.0 & 0.8 & 0.7 \\ 0.8 & 1.0 & 0.6 \\ 0.7 & 0.6 & 1.0 \end{pmatrix}
\qquad
K_{\X\Y} = \begin{pmatrix} 0.3 & 0.1 & 0.2 \\ 0.2 & 0.3 & 0.3 \\ 0.2 & 0.1 & 0.4 \end{pmatrix}
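In code, the plug-in estimator is just three kernel-matrix means; a hedged numpy sketch (illustrative, not the implementation used in the talk's experiments):

```python
# Sketch: the plug-in (biased) estimator
#   MMD_hat^2(X, Y) = mean(K_XX) + mean(K_YY) - 2 mean(K_XY),
# here with a Gaussian kernel; the toy matrices check the formula by hand.
import numpy as np

def gaussian_kernel(A, B, sigma=1.0):
    return np.exp(-((A[:, None] - B[None]) ** 2).sum(-1) / (2 * sigma ** 2))

def mmd2_biased(X, Y, sigma=1.0):
    K_XX = gaussian_kernel(X, X, sigma)
    K_YY = gaussian_kernel(Y, Y, sigma)
    K_XY = gaussian_kernel(X, Y, sigma)
    return K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()

# plugging in the toy 3x3 matrices shown above:
K_XX = np.array([[1.0, 0.2, 0.6], [0.2, 1.0, 0.5], [0.6, 0.5, 1.0]])
K_YY = np.array([[1.0, 0.8, 0.7], [0.8, 1.0, 0.6], [0.7, 0.6, 1.0]])
K_XY = np.array([[0.3, 0.1, 0.2], [0.2, 0.3, 0.3], [0.2, 0.1, 0.4]])
mmd2_hat = K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()   # ≈ 0.96
```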

I: Two-sample testing

  • Given samples from two unknown distributions \X \sim \PP \qquad \Y \sim \QQ
  • Question: is \PP = \QQ ?
  • Hypothesis testing approach: H_0: \PP = \QQ \qquad H_1: \PP \ne \QQ
  • Reject H_0 if test statistic \hat T(\X, \Y) > c_\alpha
  • Do smokers/non-smokers get different cancers?
  • Do Brits have the same friend network types as Americans?
  • When does my laser agree with the one on Mars?
  • Are storms in the 2000s different from storms in the 1800s?
  • Does presence of this protein affect DNA binding? [MMDiff2]
  • Do these dob and birthday columns mean the same thing?
  • Does my generative model \Qtheta match \Pdata ?
  • Independence testing: is P(\X, \Y) = P(\X) P(\Y) ?

What's a hypothesis test again?

Permutation testing to find c_\alpha

Need \Pr_{H_0}\left( \hat T(\X, \Y) > c_\alpha \right) \le \alpha

(Pool the samples \cP{X_1}, \dots, \cP{X_5}, \cQ{Y_1}, \dots, \cQ{Y_5} , then repeatedly shuffle them into new splits \cP{\tilde X_i}, \cQ{\tilde Y_i} .)

c_\alpha : the 1-\alpha quantile of \left\{ \hat T(\cP{\tilde X_1}, \cQ{\tilde Y_1}), \; \hat T(\cP{\tilde X_2}, \cQ{\tilde Y_2}), \; \cdots \right\}
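A hedged numpy sketch of the whole permutation test, using the biased \mmdhat^2 as the statistic (the real tests use more careful implementations; all constants here are illustrative):

```python
# Sketch: permutation two-sample test with the (biased) MMD^2 statistic.
# alpha, n_perms, and the toy Gaussian data are illustrative choices.
import numpy as np

def mmd2_stat(X, Y, sigma=1.0):
    k = lambda A, B: np.exp(-((A[:, None] - B[None]) ** 2).sum(-1) / (2 * sigma ** 2))
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()

def permutation_test(X, Y, statistic, n_perms=500, alpha=0.05, seed=0):
    rng = np.random.default_rng(seed)
    t_obs = statistic(X, Y)
    pooled, n = np.concatenate([X, Y]), len(X)
    null_stats = []
    for _ in range(n_perms):                       # re-split the pooled sample
        perm = rng.permutation(len(pooled))
        null_stats.append(statistic(pooled[perm[:n]], pooled[perm[n:]]))
    c_alpha = np.quantile(null_stats, 1 - alpha)   # the (1 - alpha) quantile
    return t_obs > c_alpha, t_obs, c_alpha

rng = np.random.default_rng(1)
X = rng.normal(0.0, 1.0, size=(100, 2))
Y = rng.normal(0.3, 1.0, size=(100, 2))
reject, t_obs, c_alpha = permutation_test(X, Y, mmd2_stat)
```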

MMD-based tests

  • If k is characteristic, \mmd(\PP, \QQ) = 0 iff \PP = \QQ
  • Efficient permutation testing for \mmdhat(\X, \Y)
    • H_0 : n \mmdhat^2 converges in distribution
    • H_1 : \sqrt n ( \mmdhat^2 - \mmd^2 ) asymptotically normal
  • Any characteristic kernel gives a consistent test… eventually
  • Need an enormous n if the kernel is poorly suited to the problem

Classifier two-sample tests

  • Train a classifier f to distinguish \X from \Y ; \hat T(\X, \Y) is its accuracy on the held-out test set
  • Under H_0 , classification is impossible: n \hat T \sim \mathrm{Binomial}(n, \frac12)
  • With k(x, y) = \frac14 f(x) f(y) where f(x) \in \{-1, 1\} ,
    get \mmdhat(\X, \Y) = \left\lvert \hat{T}(\X, \Y) - \frac12 \right\rvert

Optimizing test power

  • Asymptotics of \mmdhat^2 give us immediately that \Pr_{H_1}\left( n \mmdhat^2 > c_\alpha \right) \approx \Phi\left( \frac{\sqrt n \mmd^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt n \sigma_{H_1}} \right)
  • \mmd , \sigma_{H_1} , c_\alpha are constants in n : the first term dominates
  • Pick k to maximize an estimate of \mmd^2 / \sigma_{H_1}
  • Can show uniform \mathcal O_P(n^{-\frac13}) convergence of this estimator
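Here is a hedged PyTorch sketch of that recipe: a small feature network \phi_\psic with a Gaussian kernel on top, trained by gradient ascent on an estimate of \mmd^2 / \sigma_{H_1} . The architecture, the plug-in variance estimate, and the clamp are illustrative simplifications; the papers' kernel parametrization and regularization differ in details.

```python
# Hedged sketch (not the authors' code): learn a deep kernel
#   k_psi(x, y) = exp(-||phi_psi(x) - phi_psi(y)||^2 / (2 sigma^2))
# by maximizing an estimate of MMD^2 / sigma_{H1}. Sizes are illustrative.
import torch

def rbf(a, b, sigma):
    return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma ** 2))

def mmd2_and_var(K_xx, K_yy, K_xy):
    # biased MMD^2 and a plug-in estimate of the H1 variance (same-size samples)
    mmd2 = K_xx.mean() + K_yy.mean() - 2 * K_xy.mean()
    H = K_xx + K_yy - K_xy - K_xy.t()
    var = 4 * (H.mean(dim=1) ** 2).mean() - 4 * H.mean() ** 2
    return mmd2, var

phi = torch.nn.Sequential(               # feature extractor phi_psi
    torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 32))
log_sigma = torch.nn.Parameter(torch.zeros(()))
opt = torch.optim.Adam(list(phi.parameters()) + [log_sigma], lr=1e-3)

def neg_power_criterion(X, Y):
    fX, fY = phi(X), phi(Y)
    sigma = log_sigma.exp()
    mmd2, var = mmd2_and_var(rbf(fX, fX, sigma), rbf(fY, fY, sigma), rbf(fX, fY, sigma))
    return -mmd2 / torch.sqrt(var.clamp_min(1e-8))

# one training step on training-split batches X, Y (each of shape (n, 2)):
# loss = neg_power_criterion(X, Y); opt.zero_grad(); loss.backward(); opt.step()
```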

Blobs dataset

Blobs kernels

Blobs results

CIFAR-10 vs CIFAR-10.1

Train on 1 000, test on 1 031, repeat 10 times. Rejection rates:
ME      SCF     C2ST    MMD-O   MMD-D
0.588   0.171   0.452   0.316   0.744

Ablation vs classifier-based tests

                        Cross-entropy           Max power
Dataset                 Sign   Lin    Ours      Sign   Lin    Ours
Blob                    0.84   0.94   0.90      –      0.95   0.99
High-d Gauss. mix.      0.47   0.59   0.29      –      0.64   0.66
Higgs                   0.26   0.40   0.35      –      0.30   0.40
MNIST vs GAN            0.65   0.71   0.80      –      0.94   1.00

II: Training implicit generative models

Given samples from a distribution \PP over \mathcal X ,
we want a model that can produce new samples from \Qtheta \approx \PP


\X \sim \PP

\Y \sim \Qtheta
“Everybody Dance Now” [Chan et al. ICCV-19]

Generator networks

Fixed distribution of latents: \Z \sim \mathrm{Uniform}\left( [-1, 1]^{100} \right)

Maps through a network: G_\thetac(\Z) \sim \Qtheta

DCGAN generator [Radford+ ICLR-16]
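A hedged toy version of this setup (a tiny fully connected network standing in for the DCGAN generator; all sizes illustrative):

```python
# Toy sketch of an implicit generative model: a fixed latent distribution
# pushed through a network G_theta gives samples Y ~ Q_theta. The tiny MLP
# here stands in for a DCGAN-style convolutional generator.
import torch

G = torch.nn.Sequential(
    torch.nn.Linear(100, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, 64 * 64 * 3), torch.nn.Tanh())   # pixel values in [-1, 1]

Z = 2 * torch.rand(16, 100) - 1            # Z ~ Uniform([-1, 1]^100)
Y = G(Z).reshape(16, 3, 64, 64)            # a batch of 16 "images" from Q_theta
```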

How to choose \thetac ?

GANs and their flaws

  • GANs [Goodfellow+ NeurIPS-14] minimize discriminator accuracy (like classifier test) between \PP and \Qtheta
  • Problem: if there's a perfect classifier, the loss is discontinuous in \thetac , giving no gradient to improve the generator [Arjovsky/Bottou ICLR-17]
  • Disjoint at init: \PP and \cQ{\QQ_{\theta_0}} have non-overlapping supports
  • For usual \Gtheta : \R^{100} \to \R^{64 \times 64 \times 3} , \Qtheta is supported on a countable union of manifolds with dim \le 100
  • The “natural image manifold” is usually considered low-dimensional
  • So the supports won't align at initialization, and without useful gradients they never will

WGANs and MMD GANs

  • Integral probability metrics with “smooth” \mathcal F are continuous
  • WGAN: \mathcal F a set of neural networks satisfying \norm{f}_L \le 1
  • WGAN-GP: instead penalize \E \norm{\nabla_x f(x)} near the data
  • Both losses are MMD with k_\psic(x, y) = \phi_\psic(x) \phi_\psic(y)
    • \min_\theta \left[ \optmmd(\PP, \Qtheta) = \sup_{\psic \in \Psic} \mmd_{\psic}(\PP, \Qtheta) \right]
  • Some kind of constraint on \phi_\psic is important!
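A hedged sketch of the resulting min-max game in code, with a single Gaussian kernel on critic features standing in for the kernels used in practice; the constraints on \phi_\psic discussed next are omitted here:

```python
# Sketch of the MMD-GAN min-max game (gradient control from the next slides
# is omitted): min_theta max_psi MMD^2_{k_psi}(P, Q_theta),
# with k_psi a Gaussian kernel on critic features phi_psi. Toy networks only.
import torch

G = torch.nn.Sequential(                      # toy generator G_theta
    torch.nn.Linear(100, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, 64 * 64 * 3), torch.nn.Tanh())
phi_psi = torch.nn.Sequential(                # toy critic featurizer phi_psi
    torch.nn.Linear(64 * 64 * 3, 256), torch.nn.ReLU(), torch.nn.Linear(256, 32))

def mmd2(a, b, sigma=1.0):
    k = lambda u, v: torch.exp(-torch.cdist(u, v) ** 2 / (2 * sigma ** 2))
    return k(a, a).mean() + k(b, b).mean() - 2 * k(a, b).mean()

def critic_loss(x_real, z):                   # critic maximizes MMD^2
    return -mmd2(phi_psi(x_real), phi_psi(G(z).detach()))

def generator_loss(x_real, z):                # generator minimizes MMD^2
    return mmd2(phi_psi(x_real), phi_psi(G(z)))
```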

Non-smoothness of plain MMD GANs

Illustrative problem in \R , DiracGAN [Mescheder+ ICML-18]:

  • Just need to stay away from tiny bandwidths \psi …
  • …but the analogous fix for a deep kernel is hard to formulate.
  • Instead, keep witness function from being too steep
  • \sup_x \lVert \nabla f(x) \rVert would give Wasserstein
    • Nice distance, but hard to estimate
  • Control \lVert \nabla f(\Xtilde) \rVert on average, near the data

MMD-GAN with gradient control

  • If \Psic gives uniformly Lipschitz critics, \optmmd is smooth
  • Original MMD-GAN paper [Li+ NeurIPS-17]: box constraint
  • We [Bińkowski+ ICLR-18] used gradient penalty on critic instead
    • Better in practice, but doesn't fix the Dirac problem…

New distance: Scaled MMD

Want to ensure \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ] \le 1

Could enforce this exactly using \langle \partial_i \phi(x), f \rangle_{\cH} = \partial_i f(x) …but that's too expensive!

Guaranteed if \lVert f \rVert_\cH \le \sigma_{\SS,k,\lambda}

\sigma_{\SS,k,\lambda} := \left( \lambda + \E_{\Xtilde \sim \SS}\left[ k(\Xtilde, \Xtilde) + [\nabla_1 \!\cdot\! \nabla_2 k](\Xtilde, \Xtilde) \right] \right)^{-\frac12}

Gives distance \smmd_{\SS,k,\lambda}(\PP, \QQ) = \sigma_{\SS,k,\lambda} \mmd_k(\PP, \QQ)

\begin{align} \optmmd \text{ has } & \mathcal F = \bigcup_{\psic \in \Psic} \left\{ f : \lVert f \rVert_{\h_{\psic}} \le 1 \right\} \\ \optsmmd \text{ has } & \mathcal F = \bigcup_{\psic \in \Psic} \left\{ f : \lVert f \rVert_{\h_{\psic}} \le \sigma_{\SS,k,\lambda} \right\} \end{align}
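As a quick sanity check (a worked example of my own, for the plain Gaussian kernel k(x, y) = \exp\left(-\norm{x - y}^2 / (2\sigma^2)\right) on \R^d , with no deep features): k(\Xtilde, \Xtilde) = 1 and [\nabla_1 \!\cdot\! \nabla_2 k](\Xtilde, \Xtilde) = d / \sigma^2 for every \Xtilde , so

\begin{align}
\sigma_{\SS,k,\lambda} = \left( \lambda + 1 + \frac{d}{\sigma^2} \right)^{-\frac12}
\qquad\text{and}\qquad
\smmd_{\SS,k,\lambda}(\PP, \QQ) = \frac{\mmd_k(\PP, \QQ)}{\sqrt{\lambda + 1 + d/\sigma^2}} ,
\end{align}

so in the supremum over kernels, tiny bandwidths are heavily down-weighted.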

Deriving the Scaled MMD

\begin{gather}
\E_{\Xtilde \sim \SS}[ f(\Xtilde)^2 ]
+ \E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ]
+ \lambda \lVert f \rVert_\h^2 \le 1
\\
\E_{\Xtilde \sim \SS}[ f(\Xtilde)^2 ]
= \left\langle f, \E_{\Xtilde \sim \SS}\left[ k(\Xtilde, \cdot) \otimes k(\Xtilde, \cdot) \right] f \right\rangle
\\
\E_{\Xtilde \sim \SS}[ \lVert \nabla f(\Xtilde) \rVert^2 ]
= \left\langle f, \E_{\Xtilde \sim \SS}\left[ \sum_{i=1}^d \partial_i k(\Xtilde, \cdot) \otimes \partial_i k(\Xtilde, \cdot) \right] f \right\rangle
\\
\langle f, D_\lambda f \rangle
\le \lVert D_\lambda \rVert \, \lVert f \rVert_\h^2
\le \sigma_{\SS,k,\lambda}^{-2} \lVert f \rVert_\h^2
\end{gather}

Smoothness of \D_\mathrm{SMMD}

Theorem: \optsmmd is continuous.

If \SS has a density; \ktop is Gaussian/linear/…;
\phi_\psic is fully-connected, Leaky-ReLU, non-increasing width;
all weights in \Psic have bounded condition number; then

\W(\QQ_n, \PP) \to 0 \text{ implies } \optsmmd(\QQ_n, \PP) \to 0 .

Results on 160 \times 160 CelebA

SN-SMMD-GAN (KID: 0.006)
WGAN-GP (KID: 0.022)

Training process on CelebA


Evaluating generative models

  • Human evaluation: good at precision, bad at recall
  • Likelihood: hard for GANs, and maybe not the right thing anyway
  • Two-sample tests: with enough samples, they always reject!
  • Most common: Fréchet Inception Distance, FID
    • Run pretrained featurizer on model and target
    • Model each as Gaussian; compute W_2
    • Strong bias, small variance: very misleading
    • Simple examples where \operatorname{FID}(\cQ{\QQ_1}) > \operatorname{FID}(\cQ{\QQ_2}) but \widehat{\operatorname{FID}}(\cQ{\hat\QQ_1}) < \widehat{\operatorname{FID}}(\cQ{\hat\QQ_2}) for reasonable sample size
  • Our KID: \mmd^2 instead. Unbiased, asymptotically normal
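A hedged sketch of a KID-style computation: an unbiased \mmd^2 estimate between feature vectors with a cubic polynomial kernel. In practice the features come from a pretrained Inception network and estimates are averaged over subsets; the random features below are placeholders.

```python
# Sketch of a KID-style score: an unbiased MMD^2 estimate between feature
# vectors, with a cubic polynomial kernel k(x, y) = (x.y / d + 1)^3.
# The random "features" below are placeholders for featurizer outputs.
import numpy as np

def poly_kernel(A, B):
    d = A.shape[1]
    return (A @ B.T / d + 1.0) ** 3

def mmd2_unbiased(X, Y):
    K_XX, K_YY, K_XY = poly_kernel(X, X), poly_kernel(Y, Y), poly_kernel(X, Y)
    n, m = len(X), len(Y)
    term_xx = (K_XX.sum() - np.trace(K_XX)) / (n * (n - 1))   # drop i = j terms
    term_yy = (K_YY.sum() - np.trace(K_YY)) / (m * (m - 1))
    return term_xx + term_yy - 2.0 * K_XY.mean()

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 2048))     # placeholder for real-image features
model_feats = rng.normal(size=(500, 2048))    # placeholder for model-sample features
kid_estimate = mmd2_unbiased(real_feats, model_feats)
```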

Recap

Combining a deep architecture with a kernel machine that takes the higher-level learned representation as input can be quite powerful.

— Y. Bengio & Y. LeCun (2007), “Scaling Learning Algorithms towards AI”

  • Two-sample testing [ICLR-17, ICML-20]
    • Choose \psic to maximize power criterion
    • Exploit closed form of f^*_\psic for permutation testing
  • Generative modeling with MMD GANs [ICLR-18, NeurIPS-18]
    • Need a smooth loss function for the generator
    • Better gradients for generator to follow (?)

Future uses of deep kernel distances

  • Selective inference to avoid train/test split? Meta-testing?
  • When \PP \ne \QQ , can we tell how they're different?
    • Existing methods: some work mostly in low dimensions
    • Some look at points where the critic function is large
  • Does model \QQ match dataset \X (Stein testing)?
  • Maximize a deep dependence measure for unsupervised representation learning, as in contrastive learning