by learning kernel mean embeddings

University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)

slides (but not the talk) about four related projects:

Feng Liu

Wenkai Xu

Hsiao-Yu (Fish) Tung

Yazhe Li

Roman Pogodin

Michael Arbel

Mikołaj Bińkowski

Li (Kevin) Wenliang

Heiko Strathmann

+ all with Arthur Gretton

LIKE22 - 12 Jan 2022


- Deep learning: models usually of the form $f(x) = \langle w, \phi_\theta(x) \rangle$
- With a *learned* feature map $\phi_\theta$
- If we fix $\theta$, we have a kernel $k_\theta(x, y) = \langle \phi_\theta(x), \phi_\theta(y) \rangle$ with RKHS $\mathcal{H}_\theta$
- Same idea as the NNGP approximation

- Could train a classifier by:
- Let $L(\theta)$ be the training loss of the best $f \in \mathcal{H}_\theta$
- Learn $\theta$ by following $\nabla_\theta L(\theta)$

- Generalize to a **deep kernel**:
- Take $\phi_\theta$ as the output of the *last* layer
- The final function in $\mathcal{H}_\theta$ will be of the form $\langle w, \phi_\theta(x) \rangle$
- With logistic loss: this is Platt scaling
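A minimal sketch of the idea (the weights here are random stand-ins for learned parameters $\theta$): once $\phi_\theta$ is fixed, the network defines an ordinary finite-dimensional kernel.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical "learned" weights theta for a tiny two-layer feature map.
W1 = rng.normal(size=(5, 16)) / np.sqrt(5)
W2 = rng.normal(size=(16, 8)) / np.sqrt(16)

def phi(X):
    """Feature map phi_theta(x): the output of the last hidden layer."""
    return np.maximum(X @ W1, 0.0) @ W2

def deep_kernel(X, Y):
    """k_theta(x, y) = <phi_theta(x), phi_theta(y)>."""
    return phi(X) @ phi(Y).T

X = rng.normal(size=(4, 5))
K = deep_kernel(X, X)
# With theta fixed, k_theta is a valid kernel: K is symmetric PSD.
```

Any function in the resulting RKHS is linear in $\phi_\theta$, matching the slide's point about the final layer.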

- This definitely does *not* say that deep learning is (even approximately) a kernel method
- …despite what some people might want you to think
- We know theoretically deep learning can learn some things faster than any kernel method [see Malach+ ICML-21 + refs]
- But deep kernel learning ≠ traditional kernel models
- exactly like how usual deep learning ≠ linear models

- In “normal” classification: slightly richer function space 🤷
- Meta-learning: common representation $\phi_\theta$, constantly varying tasks
- Two-sample testing
- Simple form of the statistic for cheap permutation testing

- Self-supervised learning
- Better understanding of what's really going on, at least

- Generative modeling with MMD GANs
- Better gradient for generator to follow (?)

- Score matching in exponential families (density estimation)
- Optimize regularization weights, better gradient (?)

- $\phi$ is the *feature map* for $k$: $k(x, y) = \langle \phi(x), \phi(y) \rangle$
- If $X \sim P$, $Y \sim Q$, then $\mathrm{MMD}(P, Q) = \lVert \mathbb{E}[\phi(X)] - \mathbb{E}[\phi(Y)] \rVert$: the MMD is the distance between feature means
- Many kernels: **infinite-dimensional** $\phi$

Example kernel matrices, within-sample ($K_{XX}$, $K_{YY}$) and cross-sample ($K_{XY}$):

$$K_{XX} = \begin{pmatrix} 1.0 & 0.2 & 0.6 \\ 0.2 & 1.0 & 0.5 \\ 0.6 & 0.5 & 1.0 \end{pmatrix} \qquad K_{YY} = \begin{pmatrix} 1.0 & 0.8 & 0.7 \\ 0.8 & 1.0 & 0.6 \\ 0.7 & 0.6 & 1.0 \end{pmatrix} \qquad K_{XY} = \begin{pmatrix} 0.3 & 0.1 & 0.2 \\ 0.2 & 0.3 & 0.3 \\ 0.2 & 0.1 & 0.4 \end{pmatrix}$$
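A minimal sketch of the resulting estimator, assuming the three example matrices above are the within-$X$, within-$Y$, and cross kernel matrices: the biased $\widehat{\mathrm{MMD}}^2$ is just a combination of their means.

```python
import numpy as np

# The three example kernel matrices from the slide.
K_XX = np.array([[1.0, 0.2, 0.6], [0.2, 1.0, 0.5], [0.6, 0.5, 1.0]])
K_YY = np.array([[1.0, 0.8, 0.7], [0.8, 1.0, 0.6], [0.7, 0.6, 1.0]])
K_XY = np.array([[0.3, 0.1, 0.2], [0.2, 0.3, 0.3], [0.2, 0.1, 0.4]])

def mmd2_biased(K_XX, K_YY, K_XY):
    """Biased MMD^2 estimate: squared distance between empirical feature means."""
    return K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()

est = mmd2_biased(K_XX, K_YY, K_XY)  # high within, low across -> large MMD
```

High within-sample similarity and low cross-sample similarity give a large estimate, suggesting the two samples come from different distributions.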

- Given samples from two unknown distributions $P$ and $Q$
- Question: is $P = Q$?
- Hypothesis testing approach:
- Reject $H_0 : P = Q$ if the test statistic exceeds a threshold $c_\alpha$

- Do smokers/non-smokers get different cancers?
- Do Brits have the same friend network types as Americans?
- When does my laser agree with the one on Mars?
- Are storms in the 2000s different from storms in the 1800s?
- Does presence of this protein affect DNA binding? [`MMDiff2`]
- Do these `dob` and `birthday` columns mean the same thing?
- Does my generative model match the data distribution?
- Independence testing: is $P_{XY} = P_X P_Y$?

Need $c_\alpha$: the $(1 - \alpha)$th quantile of the test statistic's null distribution

- If $k$ is *characteristic*, $\mathrm{MMD}(P, Q) = 0$ iff $P = Q$
- Efficient permutation testing for $c_\alpha$
- Under $H_0$: the statistic converges in distribution
- Under $H_1$: the statistic is asymptotically normal

- Any characteristic kernel gives a consistent test… eventually
- Need enormous $n$ if the kernel is bad for the problem
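A hedged sketch of such a permutation test with a fixed Gaussian kernel; the bandwidth, sample sizes, and permutation count are arbitrary illustration choices:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_kernel(A, B, bw=1.0):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * bw**2))

def mmd2(X, Y, bw=1.0):
    """Biased MMD^2 estimate under a Gaussian kernel."""
    return (gaussian_kernel(X, X, bw).mean()
            + gaussian_kernel(Y, Y, bw).mean()
            - 2 * gaussian_kernel(X, Y, bw).mean())

def mmd_permutation_test(X, Y, n_perms=200, alpha=0.05, bw=1.0):
    """Reject iff the statistic exceeds the estimated (1 - alpha) null quantile."""
    stat = mmd2(X, Y, bw)
    pooled = np.vstack([X, Y])
    n = len(X)
    null_stats = []
    for _ in range(n_perms):
        idx = rng.permutation(len(pooled))       # re-split the pooled sample
        null_stats.append(mmd2(pooled[idx[:n]], pooled[idx[n:]], bw))
    c_alpha = np.quantile(null_stats, 1 - alpha)  # estimated threshold
    return stat > c_alpha

X = rng.normal(0.0, 1.0, size=(50, 2))
Y = rng.normal(2.0, 1.0, size=(50, 2))  # clearly different distribution
```

Permuting the pooled sample simulates the null $P = Q$, which is why the threshold can be estimated without knowing the statistic's distribution.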

- The statistic is the accuracy of the classifier on the test set
- Under $H_0$, classification is impossible: accuracy tends to $\tfrac{1}{2}$
- With a kernel built from the classifier's outputs, we
get the classifier test as a special case of an MMD test

- Asymptotics of $\widehat{\mathrm{MMD}}^2$ give us immediately that test power $\approx \Phi\!\left( \frac{\sqrt{n}\,\mathrm{MMD}^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt{n}\,\sigma_{H_1}} \right)$; since $\mathrm{MMD}$, $\sigma_{H_1}$, $c_\alpha$ are constants, the first term dominates
- Pick $k$ to maximize an estimate of $\mathrm{MMD}^2 / \sigma_{H_1}$
- Can show uniform convergence of this estimator
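A simplified sketch of kernel selection by this criterion. The variance estimate below is a crude first-order plug-in, and the Gaussian-bandwidth grid stands in for optimizing a deep kernel's parameters by gradient ascent:

```python
import numpy as np

rng = np.random.default_rng(1)

def gaussian_kernel(A, B, bw):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * bw**2))

def power_criterion(X, Y, bw, lam=1e-8):
    """Estimate MMD^2 / sigma_H1, the quantity to maximize over kernels."""
    Kxx = gaussian_kernel(X, X, bw)
    Kyy = gaussian_kernel(Y, Y, bw)
    Kxy = gaussian_kernel(X, Y, bw)
    H = Kxx + Kyy - Kxy - Kxy.T         # pairwise h(z_i, z_j) terms
    mmd2 = H.mean()
    var = 4.0 * H.mean(axis=1).var()    # crude first-order variance estimate
    return mmd2 / np.sqrt(var + lam)    # lam keeps the ratio well-defined

X = rng.normal(0.0, 1.0, size=(200, 1))
Y = rng.normal(0.5, 1.0, size=(200, 1))

# Hypothetical kernel "family": Gaussian kernels on a bandwidth grid.
bandwidths = [0.1, 0.5, 1.0, 2.0, 5.0]
best_bw = max(bandwidths, key=lambda b: power_criterion(X, Y, b))
```

In the deep-kernel setting the same ratio is maximized over $\theta$ on held-out training data, then the test is run on the remaining split.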

Train on 1,000, test on 1,031, repeat 10 times. Rejection rates:

ME | SCF | C2ST | MMD-O | MMD-D
---|---|---|---|---
0.588 | 0.171 | 0.452 | 0.316 | 0.744

Kernels trained with cross-entropy (CE) vs the max-power criterion (MP):

Dataset | Sign (CE) | Lin (CE) | Ours (CE) | Sign (MP) | Lin (MP) | Ours (MP)
---|---|---|---|---|---|---
Blobs | 0.84 | 0.94 | 0.90 | – | 0.95 | 0.99
High-$d$ Gauss. mix. | 0.47 | 0.59 | 0.29 | – | 0.64 | 0.66
Higgs | 0.26 | 0.40 | 0.35 | – | 0.30 | 0.40
MNIST vs GAN | 0.65 | 0.71 | 0.80 | – | 0.94 | 1.00

- What if you don't have much data for your testing problem?
- Need enough data to pick a good kernel
- Also need enough test data to actually detect the difference
- Best split depends on the best kernel's quality / how hard it is to find
- We don't know that ahead of time, and can't try more than one split

- One idea: what if we have *related* problems?
- Similar setup to meta-learning (figure from Wei+ 2018)

- CIFAR-10 has 60,000 images, but CIFAR-10.1 only has 2,031
- Where do we get related data from?
- One option: set up tasks to distinguish classes of CIFAR-10 (airplane vs automobile, airplane vs bird, ...)

The inner-loop adaptation is, e.g., 5 steps of gradient descent

we learn the initialization, maybe step size, etc.

This works, but not as well as we'd hoped…

The initialization might work okay on everything without really adapting

Inspired by classic multiple kernel learning

Only need to learn a linear combination on the test task: much easier

- Same big-O dependence on test-task size 😐
- But the multiplier is *much* better: based on the number of meta-training tasks, not on network size
- (Analysis assumes meta-tasks are “related” enough)
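The “only learn a linear combination on the test task” idea can be sketched as follows. The base kernels here are plain Gaussian kernels standing in for meta-learned deep kernels, and the simplex grid search stands in for a proper optimizer (which would also maximize a power criterion rather than raw $\widehat{\mathrm{MMD}}^2$):

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_kernel(A, B, bw):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * bw**2))

def mmd2(X, Y, bw):
    return (gaussian_kernel(X, X, bw).mean()
            + gaussian_kernel(Y, Y, bw).mean()
            - 2 * gaussian_kernel(X, Y, bw).mean())

# Hypothetical base kernels: bandwidths standing in for kernels
# learned on related meta-training tasks.
base_bws = [0.1, 1.0, 10.0]

# Small test task: few samples, but only len(base_bws) weights to fit.
X = rng.normal(0.0, 1.0, size=(20, 2))
Y = rng.normal(1.0, 1.0, size=(20, 2))

# MMD^2 is linear in the kernel, hence in the combination weights beta.
base_mmd2 = np.array([mmd2(X, Y, bw) for bw in base_bws])

def combo_mmd2(beta):
    return beta @ base_mmd2

# Crude grid search over the simplex, standing in for a real optimizer.
grid = np.linspace(0.0, 1.0, 11)
candidates = [np.array([a, b, 1.0 - a - b])
              for a in grid for b in grid if a + b <= 1.0]
beta = max(candidates, key=combo_mmd2)
```

Fitting a handful of weights instead of a whole network is what makes the small test task feasible.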

- When $P \ne Q$, can we tell *how* they're different?
- Methods so far: some mostly for low-$d$
- Some look at points with a large critic function

- Finding kernels / features that *can't* do certain things
- e.g. can distinguish by emotion, but can't distinguish by skin color

- Avoid the need for data splitting (selective inference)
- Kübler+ NeurIPS-20 gave one method
- only for multiple kernel learning
- only with a data-inefficient (streaming) estimator

Given a bunch of unlabeled samples $x_i$,

want to find “good” features $\phi(x)$

(e.g. so that a linear classifier on $\phi(x)$ works with few samples)

One common approach: contrastive learning

Variants: CPC, SimCLR, MoCo, SwAV, …

InfoNCE approximates mutual information between “positive” views

But MI is invariant to transformations that are important to SSL!

With a linear kernel: HSIC is the squared norm of the cross-covariance

Estimator based on kernel matrices:

$$\widehat{\mathrm{HSIC}}(X, Y) = \frac{1}{(n-1)^2} \operatorname{tr}(K H L H)$$

- $H = I - \frac{1}{n} \mathbf{1}\mathbf{1}^\top$ is the centering matrix
- $K$ is the kernel matrix on $X$
- $L$ is the kernel matrix on $Y$

($Y$ is just an indicator of which source image each view came from)
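A minimal sketch of the biased estimator; the toy features and the same-source indicator kernel are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def hsic_biased(K, L):
    """Biased HSIC estimate: tr(K H L H) / (n - 1)^2."""
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n   # centering matrix
    return np.trace(K @ H @ L @ H) / (n - 1) ** 2

# Toy setup: 6 augmented views; y says which source image each came from.
y = np.array([0, 0, 1, 1, 2, 2])
L = (y[:, None] == y[None, :]).astype(float)  # same-source indicator kernel

Z_aligned = 3.0 * y[:, None] + 0.01 * rng.normal(size=(6, 1))  # clusters by source
Z_random = rng.normal(size=(6, 1))                              # ignores source

K_aligned = Z_aligned @ Z_aligned.T  # linear kernel on the representation
K_random = Z_random @ Z_random.T
```

Representations that cluster by source image score higher, which is exactly the dependence between features and source labels that SSL-HSIC maximizes.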

Target representation is the output of the network

HSIC uses a learned kernel on top of it

Very similar loss! Just a different regularizer

When the variance is small, the two objectives roughly coincide

SSL-HSIC estimates the agreement of the representation with the cluster structure of the augmentations

With linear kernels: the loss pulls each view toward $\bar{z}$, the mean of its image's augmentations

Resembles BYOL loss with no target network (but still works!)

Given samples from a distribution $P$ over $\mathcal{X}$,

we want a model that can produce new samples from $P$
we want a model that can produce new samples from

“Everybody Dance Now” [Chan+ ICCV-19]

Fixed distribution of latents: $Z \sim \mathcal{N}(0, I)$

Maps through a network: $x = G_\psi(z)$

DCGAN generator [Radford+ ICLR-16]

How to choose $\psi$?

- GANs [Goodfellow+ NeurIPS-14] minimize discriminator accuracy (like the classifier test) between $P$ and $G_\psi(Z)$
- Problem: if there's a perfect classifier, the loss is discontinuous, with no gradient to improve it [Arjovsky/Bottou ICLR-17]
- Disjoint at init:
- For usual $Z$, $G_\psi(Z)$ is supported on a countable union of manifolds with dimension at most $\dim(Z)$
- “Natural image manifold” usually considered low-dim
- Won't align at init, so won't ever align

- Integral probability metrics with a “smooth” critic class $\mathcal{F}$ are continuous
- WGAN: $\mathcal{F}$ a set of neural networks satisfying a Lipschitz-type constraint
- WGAN-GP: instead penalize the critic's gradient near the data
- Both losses are MMD with a linear kernel on learned features
- Some kind of constraint on the critic is important!

Illustrative problem in $\mathbb{R}$, DiracGAN [Mescheder+ ICML-18]:

- Just need to stay away from tiny bandwidths
- …but the deep kernel analogue is hard.
- Instead, keep the witness function from being too steep
- Enforcing $\lVert f \rVert_{\mathrm{Lip}} \le 1$ would give Wasserstein
- A nice distance, but hard to estimate
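The DiracGAN failure is easy to check in closed form. With target $P = \delta_0$, model $\delta_\psi$, and a Gaussian kernel, $\mathrm{MMD}^2(\psi) = 2\left(1 - e^{-\psi^2 / 2\sigma^2}\right)$; the sketch below (illustrative numbers only) shows the generator's gradient vanishing once the bandwidth is much smaller than the gap:

```python
import numpy as np

def dirac_mmd2(psi, bw):
    """MMD^2 between P = delta_0 and Q_psi = delta_psi, Gaussian kernel."""
    return 2.0 * (1.0 - np.exp(-psi**2 / (2.0 * bw**2)))

def dirac_mmd2_grad(psi, bw):
    """d/d psi of MMD^2: the gradient the generator would follow."""
    return 2.0 * psi / bw**2 * np.exp(-psi**2 / (2.0 * bw**2))

psi = 5.0                                # generator far from the target
g_small = dirac_mmd2_grad(psi, bw=0.1)   # tiny bandwidth: loss saturates
g_large = dirac_mmd2_grad(psi, bw=5.0)   # bandwidth at the gap's scale: useful gradient
```

This is the 1-D version of “stay away from tiny bandwidths”; the hard part is guaranteeing the analogue for a deep kernel.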

- Control the critic's gradient *on average, near the data*
- If $k_\theta$ gives uniformly Lipschitz critics, the loss is smooth
- Original MMD-GAN paper [Li+ NeurIPS-17]: box constraint on $\theta$
- We [Bińkowski+ ICLR-18] used a gradient penalty on the critic instead
- Better in practice, but doesn't fix the Dirac problem…

Want to ensure the critic's gradient is controlled near the data.

Can solve with a constrained optimization …but too expensive!

Guaranteed if the kernel is rescaled appropriately.

Gives a rescaled MMD distance.

**Theorem:** This distance is continuous, provided:
$P$ has a density;
the top-level kernel is Gaussian/linear/…;
the network is fully-connected, Leaky-ReLU, with non-increasing widths;
and all weights have bounded condition number.