Better deep learning (sometimes)
by learning kernel mean embeddings
Danica J. Sutherland (she/her), University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
slides (but not the talk) covering four related projects, joint work with:
- Feng Liu, Wenkai Xu, Hsiao-Yu (Fish) Tung (+ 6 more…)
- Yazhe Li, Roman Pogodin
- Michael Arbel, Mikołaj Bińkowski
- Li (Kevin) Wenliang, Heiko Strathmann
- …and all with Arthur Gretton
LIKE22 - 12 Jan 2022
Deep learning and deep kernels
- Deep learning: models usually of the form $f(x) = \langle w, \phi_\theta(x) \rangle$
- With a learned feature extractor $\phi_\theta$ (the network up to its last layer)
- If we fix $\phi_\theta$, we have $f \in \mathcal{H}_k$ with $k(x, y) = \langle \phi_\theta(x), \phi_\theta(y) \rangle$
- Same idea as NNGP approximation
- Could train a classifier by:
- Let $J(\theta)$ be the loss of the best $w$ for the fixed features $\phi_\theta$
- Learn $\theta$ by following $\nabla_\theta J(\theta)$
- Generalize to a deep kernel: $k_\theta(x, y) = \kappa\big(\phi_\theta(x), \phi_\theta(y)\big)$ for a base kernel $\kappa$
Normal deep learning ⊂ deep kernels
- Take $\phi_\theta$ as the output of the last layer, $\kappa$ the linear kernel
- The final function in $\mathcal{H}_{k_\theta}$ will be $f(x) = \langle w, \phi_\theta(x) \rangle$
- With logistic loss: this is Platt scaling
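As a concrete illustration of the deep-kernel idea, here is a minimal PyTorch sketch: a small MLP plays the role of $\phi_\theta$, and a Gaussian base kernel with a learnable bandwidth plays the role of $\kappa$. The class and parameter names are illustrative, not taken from any of the papers' code.

```python
# Minimal sketch of a deep kernel k_theta(x, y) = kappa(phi_theta(x), phi_theta(y)),
# with a small MLP as phi_theta and a Gaussian base kernel kappa.
# Names (DeepKernel, feature_dim, ...) are illustrative, not from any library.
import torch
import torch.nn as nn

class DeepKernel(nn.Module):
    def __init__(self, in_dim, feature_dim=32, bandwidth=1.0):
        super().__init__()
        # phi_theta: learned feature extractor (the "network up to the last layer")
        self.phi = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(),
            nn.Linear(64, feature_dim),
        )
        # log-bandwidth of the Gaussian base kernel, learned jointly with phi
        self.log_bw = nn.Parameter(torch.tensor(float(bandwidth)).log())

    def forward(self, x, y):
        # pairwise kernel matrix kappa(phi(x_i), phi(y_j)) for batches x, y
        fx, fy = self.phi(x), self.phi(y)
        sq_dists = torch.cdist(fx, fy) ** 2
        return torch.exp(-sq_dists / (2 * self.log_bw.exp() ** 2))
```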
So what?
- This definitely does not say that deep learning is (even approximately) a kernel method
- …despite what some people might want you to think
- We know theoretically deep learning can learn some things faster than any kernel method [see Malach+ ICML-21 + refs]
- But deep kernel learning ≠ traditional kernel models
- exactly like how usual deep learning ≠ linear models
What do deep kernels give us?
- In “normal” classification: slightly richer function space 🤷
- Meta-learning: common $\phi_\theta$, constantly varying $w$
- Two-sample testing
- A simple form of $\widehat{\mathrm{MMD}}$ allows cheap permutation testing
- Self-supervised learning
- Better understanding of what's really going on, at least
- Generative modeling with MMD GANs
- Better gradient for generator to follow (?)
- Score matching in exponential families (density estimation)
- Optimize regularization weights, better gradient (?)
Maximum Mean Discrepancy (MMD)
MMD as feature matching
- $\varphi$ is the feature map for $k$: $k(x, y) = \langle \varphi(x), \varphi(y) \rangle$
- If $X \sim P$ and $Y \sim Q$, the MMD is the distance between mean embeddings: $\mathrm{MMD}_k(P, Q) = \big\lVert \mathbb{E}[\varphi(X)] - \mathbb{E}[\varphi(Y)] \big\rVert$
- Many kernels: infinite-dimensional $\varphi$
Estimating MMD
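A minimal sketch of the standard unbiased (U-statistic) estimator of $\mathrm{MMD}^2$ (Gretton et al., 2012), assuming `k` is a function returning a pairwise kernel matrix (for instance the `DeepKernel` sketch above); generic reference code, not the talk's implementation.

```python
# Unbiased (U-statistic) estimator of MMD^2 between samples X ~ P and Y ~ Q,
# for a kernel function k(A, B) -> pairwise kernel matrix.
import torch

def mmd2_unbiased(k, X, Y):
    K_xx, K_yy, K_xy = k(X, X), k(Y, Y), k(X, Y)
    n, m = X.shape[0], Y.shape[0]
    # drop diagonal terms so the within-sample averages are unbiased
    sum_xx = (K_xx.sum() - K_xx.diag().sum()) / (n * (n - 1))
    sum_yy = (K_yy.sum() - K_yy.diag().sum()) / (m * (m - 1))
    return sum_xx + sum_yy - 2 * K_xy.mean()
```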
I: Two-sample testing
- Given samples $X \sim P^n$ and $Y \sim Q^m$ from two unknown distributions
- Question: is $P = Q$?
- Hypothesis testing approach:
- Reject $H_0: P = Q$ if the test statistic exceeds a threshold $c_\alpha$
- Do smokers/non-smokers get different cancers?
- Do Brits have the same friend network types as Americans?
- When does my laser agree with the one on Mars?
- Are storms in the 2000s different from storms in the 1800s?
- Does presence of this protein affect DNA binding? [MMDiff2]
- Do these dob and birthday columns mean the same thing?
- Does my generative model match the data distribution?
- Independence testing: is $P_{XY} = P_X \times P_Y$?
What's a hypothesis test again?
Permutation testing to find the threshold $c_\alpha$
Need $\Pr_{H_0}(\text{reject } H_0) \le \alpha$
$c_\alpha$: the $(1 - \alpha)$ quantile of the statistic's null distribution, estimated by shuffling the pooled samples
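A sketch of that permutation test, reusing the `mmd2_unbiased` helper above: the null distribution is simulated by recomputing the statistic on random re-splits of the pooled sample, and we reject when the observed statistic exceeds the $(1-\alpha)$ quantile. (In practice one reuses the pooled kernel matrix across permutations for speed; this naive version recomputes it.)

```python
# Permutation test for H0: P = Q using the unbiased MMD^2 statistic.
import torch

def mmd_permutation_test(k, X, Y, n_perms=200, alpha=0.05):
    observed = mmd2_unbiased(k, X, Y)            # statistic on the real split
    Z = torch.cat([X, Y], dim=0)
    n = X.shape[0]
    null_stats = []
    for _ in range(n_perms):
        perm = torch.randperm(Z.shape[0])
        Xp, Yp = Z[perm[:n]], Z[perm[n:]]        # random re-split of pooled data
        null_stats.append(mmd2_unbiased(k, Xp, Yp))
    threshold = torch.quantile(torch.stack(null_stats), 1 - alpha)
    return observed > threshold, observed, threshold
```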
MMD-based tests
- If $k$ is characteristic, $\mathrm{MMD}_k(P, Q) = 0$ iff $P = Q$
- Efficient permutation testing for $\widehat{\mathrm{MMD}}^2$
- Under $H_0$: $n\,\widehat{\mathrm{MMD}}^2$ converges in distribution (to an infinite mixture of $\chi^2$s)
- Under $H_1$: $\widehat{\mathrm{MMD}}^2$ is asymptotically normal
- Any characteristic kernel gives a consistent test…eventually
- Need enormous $n$ if the kernel is bad for the problem
Classifier two-sample tests
- The test statistic is the accuracy of a classifier $f$ on the test set
- Under $H_0$, classification is impossible: accuracy $\approx \tfrac12$
- With $k(x, y) = g(x)\, g(y)$, where $g$ indicates the classifier's predicted class,
the accuracy statistic is (up to shift and scale) an MMD estimate
Optimizing test power
- Asymptotics of $\widehat{\mathrm{MMD}}^2$ give us immediately that
$$\Pr_{H_1}(\text{reject}) \approx \Phi\!\left( \frac{\sqrt{n}\,\mathrm{MMD}^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt{n}\,\sigma_{H_1}} \right)$$
- $\mathrm{MMD}$, $\sigma_{H_1}$, $c_\alpha$ are constants: the first term dominates for large $n$
- Pick $k$ to maximize an estimate of $\mathrm{MMD}^2 / \sigma_{H_1}$
- Can show uniform convergence of this estimator
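A rough sketch of this power criterion as a differentiable training objective, following the structure of the $\widehat{\mathrm{MMD}}^2 / \hat\sigma_{H_1}$ estimators in Liu+ [ICML-20]; the small-epsilon regularizer and exact constants here are simplifications, and paired samples of equal size are assumed.

```python
# Test-power criterion MMD^2 / sigma_H1, usable as a training objective for a
# deep kernel (maximize it, i.e. minimize its negative). Simplified sketch.
import torch

def power_criterion(k, X, Y, eps=1e-8):
    n = X.shape[0]                        # assumes paired samples: len(X) == len(Y)
    Kxx, Kyy, Kxy = k(X, X), k(Y, Y), k(X, Y)
    H = Kxx + Kyy - Kxy - Kxy.t()         # H[i, j] = h((x_i, y_i), (x_j, y_j))
    mmd2 = (H.sum() - H.diag().sum()) / (n * (n - 1))
    row_means = H.sum(dim=1) / n
    var_h1 = 4 * (row_means.pow(2).mean() - H.mean() ** 2)   # variance under H1
    return mmd2 / torch.sqrt(var_h1.clamp_min(0) + eps)

# Training-loop sketch: maximize the criterion w.r.t. the deep kernel's parameters.
# kernel = DeepKernel(in_dim=d); opt = torch.optim.Adam(kernel.parameters(), lr=1e-3)
# loss = -power_criterion(kernel, X_train, Y_train); loss.backward(); opt.step()
```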
Blobs dataset
Blobs results
CIFAR-10 vs CIFAR-10.1
Train on 1 000, test on 1 031, repeat 10 times. Rejection rates:

| ME | SCF | C2ST | MMD-O | MMD-D |
|---|---|---|---|---|
| 0.588 | 0.171 | 0.452 | 0.316 | 0.744 |
Ablation vs classifier-based tests
Rejection rates, with kernels trained by cross-entropy (CE) vs. the max-power criterion:

| Dataset | Sign (CE) | Lin (CE) | Ours (CE) | Sign (power) | Lin (power) | Ours (power) |
|---|---|---|---|---|---|---|
| Blobs | 0.84 | 0.94 | 0.90 | – | 0.95 | 0.99 |
| High-$d$ Gauss. mix. | 0.47 | 0.59 | 0.29 | – | 0.64 | 0.66 |
| Higgs | 0.26 | 0.40 | 0.35 | – | 0.30 | 0.40 |
| MNIST vs GAN | 0.65 | 0.71 | 0.80 | – | 0.94 | 1.00 |
But…
- What if you don't have much data for your testing problem?
- Need enough data to pick a good kernel
- Also need enough test data to actually detect the difference
- The best split depends on the best kernel's quality and how hard it is to find
- We don't know that ahead of time, and can't try more than one split
Meta-testing
- One idea: what if we have related problems?
- Similar setup to meta-learning (figure from Wei+ 2018)
Meta-testing for CIFAR-10 vs CIFAR-10.1
- CIFAR-10 has 60,000 images, but CIFAR-10.1 only has 2,031
- Where do we get related data from?
- One option: set up tasks to distinguish classes of CIFAR-10 (airplane vs automobile, airplane vs bird, ...)
One approach (MAML-like)
The per-task adaptation is, e.g., 5 steps of gradient descent
We learn the initialization, and maybe the step size, etc.
This works, but not as well as we'd hoped…
The initialization might work okay on everything without really adapting
Another approach: Meta-MKL
Inspired by classic multiple kernel learning
Only need to learn a linear combination of the meta-learned kernels on the test task:
much easier
Theoretical analysis for Meta-MKL
- Same big-O dependence on test task size 😐
- But the multiplier is much better: it depends on the number of meta-training tasks, not on the network size
- (Analysis assumes the meta-tasks are “related” enough)
Results on CIFAR-10.1
Challenges for testing
- When $P \ne Q$, can we tell how they're different?
- Methods so far: some, mostly for low-dimensional problems
- Some look at points where the critic (witness) function is large
- Finding kernels / features that can't do certain things
- distinguish by emotion, but can't distinguish by skin color
- Avoid the need for data splitting (selective inference)
- Kübler+ NeurIPS-20 gave one method
- only for multiple kernel learning
- only with data-inefficient (streaming) estimator
II: Self-supervised learning
Given a bunch of unlabeled samples $\{X_i\}$,
want to find “good” features $\phi(X)$
(e.g. so that a linear classifier on $\phi(X)$ works with few labeled samples)
One common approach: contrastive learning
Variants:
CPC, SimCLR, MoCo, SwAV, …
Mutual information isn't why SSL works!
InfoNCE approximates MI between "positive" views
But MI is invariant to transformations important to SSL!
Hilbert-Schmidt Independence Criterion (HSIC)
- HSIC is the $\mathrm{MMD}^2$ between the joint distribution and the product of its marginals; with a linear kernel, it reduces to $\lVert \operatorname{Cov}(X, Y) \rVert_F^2$
- Estimator based on kernel matrices: $\widehat{\mathrm{HSIC}}(X, Y) = \frac{1}{(n-1)^2} \operatorname{tr}(K H L H)$
- $H = I - \tfrac{1}{n} \mathbf{1}\mathbf{1}^\top$ is the centering matrix
- $K$ is the kernel matrix on $X$
- $L$ is the kernel matrix on $Y$
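A sketch of that standard (biased) HSIC estimator, $\operatorname{tr}(KHLH)/(n-1)^2$ (Gretton et al., 2005), taking the two kernel matrices as inputs:

```python
# Biased HSIC estimator from kernel matrices K (on X) and L (on Y).
import torch

def hsic_biased(K, L):
    n = K.shape[0]
    # centering matrix H = I - (1/n) 1 1^T
    H = torch.eye(n, device=K.device, dtype=K.dtype) - torch.ones(n, n, device=K.device, dtype=K.dtype) / n
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2
```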
SSL-HSIC
($Y$ is just an indicator of which source image the augmentation came from)
Target representation is the output of the network
HSIC uses a learned kernel on that representation
Your InfoNCE model is secretly a kernel method…
Very similar loss! Just a different regularizer
When the variance of the kernel values is small, InfoNCE approximates an HSIC-based objective
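A toy sketch of an SSL-HSIC-style loss built from the `hsic_biased` estimator above: it rewards dependence between the representations and the source-image identity, and penalizes $\sqrt{\mathrm{HSIC}(Z, Z)}$. The Gaussian kernel, fixed bandwidth, and `gamma` are illustrative assumptions, not the paper's exact parameterization (which learns the kernel and uses unbiased estimates).

```python
# Illustrative SSL-HSIC-style loss on a batch of representations Z and the ids of
# the source images each augmentation came from.
import torch

def ssl_hsic_loss(Z, image_ids, gamma=1.0, bandwidth=1.0):
    K = torch.exp(-torch.cdist(Z, Z) ** 2 / (2 * bandwidth ** 2))    # kernel on features
    L = (image_ids[:, None] == image_ids[None, :]).float()           # kernel on identities
    return -hsic_biased(K, L) + gamma * hsic_biased(K, K).clamp_min(0).sqrt()
```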
Clustering interpretation
SSL-HSIC measures how well the representation agrees with the cluster structure given by the source images
With linear kernels: each augmentation's features are pulled toward $\bar z_i$, where $\bar z_i$ is the mean of the features of that image's augmentations
Resembles BYOL loss with no target network (but still works!)
ImageNet results: linear evaluation
Transfer from ImageNet to classification tasks
III: Training implicit generative models
Given samples from a distribution $\mathbb{P}$ over $\mathcal{X}$,
we want a model that can produce new samples from (approximately) $\mathbb{P}$
Generator networks
Fixed distribution of latents: $Z \sim \mathcal{N}(0, I)$ (or uniform)
Maps through a network: $X = G_\psi(Z)$, giving samples from the model $\mathbb{Q}_\psi$
DCGAN generator [Radford+ ICLR-16]
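A minimal generator sketch (a toy fully-connected network, not the DCGAN architecture referenced above): a fixed latent distribution is pushed through the network to produce samples.

```python
# Toy generator network: z ~ N(0, I) mapped through an MLP to produce samples.
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=128, out_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, out_dim), nn.Tanh(),
        )

    def sample(self, n):
        z = torch.randn(n, self.latent_dim)   # Z ~ N(0, I), the fixed latent distribution
        return self.net(z)                    # X = G_psi(Z), a sample from Q_psi
```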
How to choose $\psi$?
GANs and their flaws
- GANs [Goodfellow+ NeurIPS-14] minimize discriminator accuracy (like the classifier test) between $\mathbb{P}$ and $\mathbb{Q}_\psi$
- Problem: if there's a perfect classifier, the loss is discontinuous in $\psi$: no gradient to improve it [Arjovsky/Bottou ICLR-17]
Disjoint at init: the supports of $\mathbb{P}$ and $\mathbb{Q}_\psi$ don't overlap
- For the usual $G_\psi$, $\mathbb{Q}_\psi$ is supported on a countable union of manifolds of dimension $\le \dim(Z)$
- “Natural image manifold” usually considered low-dim
- Won't align at init, so won't ever align
WGANs and MMD GANs
- Integral probability metrics, $\sup_{f \in \mathcal{F}} \mathbb{E}_{\mathbb{P}}[f(X)] - \mathbb{E}_{\mathbb{Q}_\psi}[f(Y)]$, with a “smooth” class $\mathcal{F}$ are continuous
- WGAN: $\mathcal{F}$ is a set of neural networks satisfying a Lipschitz constraint (via weight clipping)
- WGAN-GP: instead penalize $\lVert \nabla f \rVert$ near the data
- Both losses are MMDs with $k(x, y) = f(x)\, f(y)$ for a learned critic $f$
- Some kind of constraint on the critic class is important!
Non-smoothness of plain MMD GANs
Illustrative problem in 1-D: the DiracGAN [Mescheder+ ICML-18]
- Just need to stay away from tiny bandwidths
- …the deep kernel analogue of that is hard.
- Instead, keep the witness function from being too steep
- Requiring $\lVert f \rVert_{\mathrm{Lip}} \le 1$ everywhere would give the Wasserstein distance
- Nice distance, but hard to estimate
- Instead, control the witness's gradient on average, near the data
MMD-GAN with gradient control
- If the kernel family gives uniformly Lipschitz critics, the loss is smooth
- Original MMD-GAN paper [Li+ NeurIPS-17]: box constraint on the critic's weights
- We [Bińkowski+ ICLR-18] used gradient penalty on critic instead
- Better in practice, but doesn't fix the Dirac problem…
New distance: Scaled MMD
Want to ensure the witness function's gradient is controlled near the data
Can solve this exactly with a constrained optimization…but too expensive!
Guaranteed if we instead rescale the MMD by a data-dependent factor
Gives a distance: the Scaled MMD
Deriving the Scaled MMD
Smoothness of the Scaled MMD
Theorem: the Scaled MMD is continuous in the generator's parameters if:
$\mathbb{P}$ has a density;
the base kernel is Gaussian/linear/…;
the critic network is fully-connected, Leaky-ReLU, with non-increasing widths;
and all of its weights have bounded condition number.
Results on CelebA
SN-SMMD-GAN: KID 0.006
WGAN-GP: KID 0.022
Evaluating generative models
- Human evaluation: good at precision, bad at recall
- Likelihood: hard for GANs, maybe not right thing anyway
- Two-sample tests: always reject!
- Most common: Fréchet Inception Distance, FID
- Run a pretrained featurizer (Inception) on model and target samples
- Model each set of features as a Gaussian; compute the Fréchet (Wasserstein-2) distance between them
- Strong bias, small variance: very misleading
- Simple examples where $\mathrm{FID}(\mathbb{P}_1, \mathbb{Q}) < \mathrm{FID}(\mathbb{P}_2, \mathbb{Q})$,
but the estimates reliably give $\widehat{\mathrm{FID}}(\mathbb{P}_1, \mathbb{Q}) > \widehat{\mathrm{FID}}(\mathbb{P}_2, \mathbb{Q})$
for reasonable sample sizes
- Our KID: an unbiased estimator of $\mathrm{MMD}^2$, with a polynomial kernel on the same features, instead. Unbiased, asymptotically normal
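KID is just the unbiased $\mathrm{MMD}^2$ estimator with the cubic polynomial kernel $k(x, y) = (x^\top y / d + 1)^3$ applied to the feature vectors; a sketch reusing `mmd2_unbiased` from earlier:

```python
# KID sketch: unbiased MMD^2 with the cubic polynomial kernel on feature vectors.
import torch

def polynomial_kernel(A, B):
    d = A.shape[1]
    return (A @ B.t() / d + 1) ** 3

def kid(features_real, features_fake):
    # reuses the unbiased MMD^2 estimator sketched earlier
    return mmd2_unbiased(polynomial_kernel, features_real, features_fake)
```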
Training process on CelebA
IV: Unnormalized density/score estimation
- Problem: given samples from an unknown distribution with density $p_0$
- Model is a kernel exponential family: $p_f(x) \propto q_0(x) \exp(f(x))$ for any $f \in \mathcal{H}_k$,
i.e. any density whose log (relative to the base $q_0$) lies in $\mathcal{H}_k$
- Gaussian $k$: dense in all continuous distributions on compact domains
Density estimation with KEFs
- Fitting with maximum likelihood is tough:
- the normalizer and its gradients are tough to compute
- the likelihood equations are ill-posed for characteristic kernels
- We choose instead to fit the unnormalized model
- Could then estimate the normalizer once after fitting, if necessary
Unnormalized density / score estimation
- Don't necessarily need to compute the normalizer afterwards
- The unnormalized log-density (the “energy”) lets us:
- Find modes (global or local)
- Sample (with MCMC)
- …
- The score, $\nabla_x \log p(x)$, lets us:
- Run HMC for targets whose gradients we can't evaluate
- Construct Monte Carlo control functionals
- …
- Idea: minimize the Fisher divergence $J(p_f) = \tfrac12 \mathbb{E}_{p_0} \lVert \nabla_x \log p_f(X) - \nabla_x \log p_0(X) \rVert^2$
- Under mild assumptions, integration by parts gives $J(p_f) = \mathbb{E}_{p_0}\!\left[ \sum_d \partial_d^2 \log p_f(X) + \tfrac12 \big(\partial_d \log p_f(X)\big)^2 \right] + \text{const}$
- Can estimate this with Monte Carlo over the samples
- Minimize a regularized loss function: $\hat J(f) + \lambda \lVert f \rVert_{\mathcal{H}}^2$
- The representer theorem tells us the minimizer of this over $\mathcal{H}$ is a finite combination of kernel derivatives at the data points (found by solving a linear system)
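For intuition, here is a generic autograd sketch of the Monte Carlo score-matching objective above for an arbitrary unnormalized log-density `f`; the papers instead exploit the RKHS structure to get a closed-form solution, so this is only illustrative.

```python
# Hyvarinen score-matching objective, estimated by Monte Carlo over samples X [n, d]:
#   J(f) ~ mean_i [ sum_d d^2 f/dx_d^2 (X_i) + 0.5 * ||grad f(X_i)||^2 ]
# Generic autograd version, not the closed-form kernel solver from the papers.
import torch

def score_matching_loss(f, X):
    X = X.clone().requires_grad_(True)
    fx = f(X).sum()
    (grad,) = torch.autograd.grad(fx, X, create_graph=True)    # [n, d] gradients
    sq_grad = 0.5 * (grad ** 2).sum(dim=1)                     # 0.5 * ||grad f||^2
    laplacian = torch.zeros(X.shape[0], device=X.device)
    for d in range(X.shape[1]):                                # trace of the Hessian
        (g2,) = torch.autograd.grad(grad[:, d].sum(), X, create_graph=True)
        laplacian = laplacian + g2[:, d]
    return (laplacian + sq_grad).mean()
```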
Score matching in a subspace
- The best $f$ lies in an $nd$-dimensional subspace spanned by kernel derivatives at the $n$ data points
- Find the best $f$ in that subspace
- in $O\big((nd)^3\big)$ time
- Nyström approximation: use only the subspace from $m \ll n$ points, making the solve much cheaper!
Meta-learning a kernel
- This was all with a fixed kernel and regularization weight $\lambda$
- Good results need a carefully tuned kernel and $\lambda$
- We can use a deep kernel as long as we split the data
- Otherwise it would always pick bandwidth 0
Results
- Learns local dataset geometry: better fits
- On real data: slightly worse likelihoods, maybe better “shapes” than deep likelihood models
Recap
Combining a deep architecture with a kernel machine that takes the higher-level learned representation as input can be quite powerful.
— Y. Bengio & Y. LeCun (2007), “Scaling Learning Algorithms towards AI”
- Two-sample testing [ICLR-17, ICML-20, NeurIPS-21]
- maximizing power criterion, for one task or many
- Self-supervised learning with HSIC [NeurIPS-21]
- Much better understanding of what's going on!
- Generative modeling with MMD GANs [ICLR-18, NeurIPS-18]
- Need a smooth loss function for the generator
- Score matching in exponential families [AISTATS-18, ICML-19]
- Avoid overfitting with closed-form fit on held-out data