Are these datasets the same?
Learning kernels for efficient and fair two-sample tests
Danica J. Sutherland (she/her), University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
Hsiao-Yu (Fish) Tung
Heiko Strathmann
Soumyajit De
Aaditya Ramdas
Alex Smola
Arthur Gretton
Feng Liu
Wenkai Xu
Jie Lu
Guangquan Zhang
Arthur Gretton
Feng Liu
Wenkai Xu
Jie Lu
new!
Namrata Deka
TrustML - 15 Feb 2022
Data drift
- The textbook ML setting:
  - Train on i.i.d. samples from some distribution $\mathbb P$
  - Training error $\approx$ test error on $\mathbb P$
  - So our model should be good on more samples from $\mathbb P$
- Really:
  - Train on "i.i.d. samples from some distribution $\mathbb P$"
  - Training error might vaguely correlate with test error on $\mathbb P$
  - Deploy it on some distribution $\mathbb Q$, which might be sort of like $\mathbb P$
  - and $\mathbb Q$ probably changes over time…
This talk
Based on samples $X \sim \mathbb P$ and $Y \sim \mathbb Q$:
- How is $\mathbb Q$ different from $\mathbb P$?
- Is $\mathbb Q$ close enough to $\mathbb P$ for our model?
- Is $\mathbb P = \mathbb Q$?
Two-sample testing
- Given samples $X \sim \mathbb P^n$ and $Y \sim \mathbb Q^m$ from two unknown distributions
- Question: is $\mathbb P = \mathbb Q$?
- Hypothesis testing approach:
  - Reject $H_0\!: \mathbb P = \mathbb Q$ if test statistic $\hat T(X, Y) > c_\alpha$
- Do smokers/non-smokers get different cancers?
- Do Brits have the same friend network types as Americans?
- When does my laser agree with the one on Mars?
- Are storms in the 2000s different from storms in the 1800s?
- Does presence of this protein affect DNA binding? [MMDiff2]
- Do these dob and birthday columns mean the same thing?
- Does my generative model match the data distribution $\mathbb P$?
- Independence testing: is $\mathbb P_{XY} = \mathbb P_X \mathbb P_Y$?
What's a hypothesis test again?
Permutation testing to find $c_\alpha$
Need $\Pr_{H_0}\!\big(\hat T > c_\alpha\big) \le \alpha$
$c_\alpha$: $(1-\alpha)$th quantile of $\hat T$'s distribution under $H_0$, estimated by recomputing $\hat T$ on shuffled splits of the pooled sample
Need a statistic $\hat T$ to estimate the difference between distributions, based on samples (a minimal sketch follows)
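A minimal NumPy sketch of the permutation procedure, assuming any statistic function of two samples; the number of permutations, quantile-based threshold, and variable names here are illustrative choices, not the exact implementation from these papers.

```python
import numpy as np

def permutation_test(x, y, statistic, n_perms=500, alpha=0.05, seed=0):
    """Estimate c_alpha from the permutation null, then reject H0: P = Q
    if the statistic on the real split exceeds it."""
    rng = np.random.default_rng(seed)
    t_obs = statistic(x, y)
    pooled = np.concatenate([x, y])
    n = len(x)
    t_null = np.empty(n_perms)
    for i in range(n_perms):
        perm = rng.permutation(len(pooled))
        t_null[i] = statistic(pooled[perm[:n]], pooled[perm[n:]])
    c_alpha = np.quantile(t_null, 1 - alpha)        # (1 - alpha)th quantile under H0
    p_value = (1 + np.sum(t_null >= t_obs)) / (1 + n_perms)
    return t_obs > c_alpha, p_value
```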
Our choice of $\hat T$: the Maximum Mean Discrepancy (MMD)
This is a kernel-based distance between distributions
What's a kernel again?
- Linear classifiers: $f(x) = \langle w, x \rangle$
- Use a "richer" feature map $\phi(x)$: $f(x) = \langle w, \phi(x) \rangle$
- Can avoid the explicit $\phi$; instead use $k(x, x') = \langle \phi(x), \phi(x') \rangle$
- "Kernelized" algorithms access data only through $k$
- The RKHS norm $\lVert f \rVert_{\mathcal H}$ gives a kernel notion of smoothness
Reproducing Kernel Hilbert Space (RKHS)
- Ex: Gaussian RBF / exponentiated quadratic / squared exponential / …: $k(x, y) = \exp\!\left(-\frac{\lVert x - y \rVert^2}{2\sigma^2}\right)$ (sketch below)
- Some functions with small $\lVert f \rVert_{\mathcal H}$:
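A quick sketch of the Gaussian RBF kernel as a Gram-matrix computation; the helper name and the bandwidth parameter `sigma` are illustrative.

```python
import numpy as np

def gaussian_gram(x, y, sigma=1.0):
    """Gram matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 sigma^2)) for 2-D arrays."""
    sq_dists = (np.sum(x ** 2, axis=1)[:, None]
                + np.sum(y ** 2, axis=1)[None, :]
                - 2.0 * x @ y.T)
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma ** 2))
```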
Maximum Mean Discrepancy (MMD)
$$\mathrm{MMD}(\mathbb P, \mathbb Q) = \sup_{\lVert f \rVert_{\mathcal H} \le 1} \mathbb E_{X \sim \mathbb P}[f(X)] - \mathbb E_{Y \sim \mathbb Q}[f(Y)]$$

The sup is achieved by the witness function $f^* \propto \mu_{\mathbb P} - \mu_{\mathbb Q}$, where $\mu_{\mathbb P} = \mathbb E_{X \sim \mathbb P}[k(X, \cdot)]$
Estimating MMD

$$\widehat{\mathrm{MMD}}^2(X, Y) = \frac{1}{n(n-1)} \sum_{i \ne j} k(X_i, X_j) + \frac{1}{m(m-1)} \sum_{i \ne j} k(Y_i, Y_j) - \frac{2}{nm} \sum_{i, j} k(X_i, Y_j)$$
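A sketch of this unbiased (U-statistic) estimator, written to accept any function that returns Gram matrices, e.g. the `gaussian_gram` helper above.

```python
import numpy as np

def mmd2_unbiased(x, y, kernel):
    """Unbiased estimate of MMD^2; within-sample averages exclude the diagonal."""
    k_xx, k_yy, k_xy = kernel(x, x), kernel(y, y), kernel(x, y)
    n, m = len(x), len(y)
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return term_xx + term_yy - 2.0 * k_xy.mean()
```

Plugging this, with some fixed kernel, into the `permutation_test` sketch above gives the basic MMD permutation test.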
MMD as feature matching
- $\phi(x) = k(x, \cdot)$ is the feature map for $k$
- If $\mu_{\mathbb P} = \mathbb E_{X \sim \mathbb P}[\phi(X)]$ and $\mu_{\mathbb Q} = \mathbb E_{Y \sim \mathbb Q}[\phi(Y)]$, then
  the MMD is the distance between mean embeddings: $\mathrm{MMD}(\mathbb P, \mathbb Q) = \lVert \mu_{\mathbb P} - \mu_{\mathbb Q} \rVert_{\mathcal H}$
- Many kernels: infinite-dimensional $\phi$
MMD-based tests
- If $k$ is characteristic, $\mathrm{MMD}(\mathbb P, \mathbb Q) = 0$ iff $\mathbb P = \mathbb Q$
- Efficient permutation testing for $c_\alpha$
- Under $H_0$: $n \, \widehat{\mathrm{MMD}}^2$ converges in distribution
- Under $H_1$: $\sqrt{n} \, (\widehat{\mathrm{MMD}}^2 - \mathrm{MMD}^2)$ is asymptotically normal
- Any characteristic kernel gives a consistent test… eventually
- Need enormous $n$ if the kernel is bad for the problem
Classifier two-sample tests
- $\hat T$ is the accuracy of a classifier $f$, trained to tell $X$ from $Y$, on a held-out test set
- Under $H_0$, classification is impossible: accuracy $\approx \tfrac12$
- With a kernel built from the classifier, $k(x, y) \propto \operatorname{sign}(f(x)) \, \operatorname{sign}(f(y))$,
  this accuracy statistic is (up to scaling and centering) an MMD estimate
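A sketch of the classifier two-sample test statistic, using scikit-learn's logistic regression as a stand-in for whatever classifier you would actually train; the train/test split and the centering by 1/2 follow the idea above, while the helper name is illustrative.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def c2st_statistic(x_train, y_train, x_test, y_test):
    """Train a classifier to tell P-samples from Q-samples; its held-out
    accuracy minus 1/2 is the test statistic (near 0 under H0)."""
    clf = LogisticRegression(max_iter=1000).fit(
        np.vstack([x_train, y_train]),
        np.concatenate([np.zeros(len(x_train)), np.ones(len(y_train))]),
    )
    acc = clf.score(
        np.vstack([x_test, y_test]),
        np.concatenate([np.zeros(len(x_test)), np.ones(len(y_test))]),
    )
    return acc - 0.5
```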
Deep learning and deep kernels
- The classifier-based kernel above is one form of deep kernel
- Deep models are usually of the form $f(x) = \langle w, \phi_\theta(x) \rangle$
  - With a learned featurizer $\phi_\theta$
- If we fix $\phi_\theta$, we have a kernel method with $k(x, y) = \langle \phi_\theta(x), \phi_\theta(y) \rangle$
  - Same idea as the NNGP approximation
- Generalize to a deep kernel: $k_\theta(x, y) = \kappa\big(\phi_\theta(x), \phi_\theta(y)\big)$ for a simple kernel $\kappa$, e.g. Gaussian (sketch below)
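A minimal PyTorch sketch of such a deep kernel: a Gaussian kernel on learned features. The published deep-kernel test also mixes in a small kernel on the raw inputs to keep the kernel characteristic; that term is omitted here for brevity, and the class and parameter names are illustrative.

```python
import torch
import torch.nn as nn

class DeepKernel(nn.Module):
    """Gaussian kernel applied to the output of a learned featurizer phi_theta."""
    def __init__(self, featurizer, sigma=1.0):
        super().__init__()
        self.featurizer = featurizer                      # any nn.Module: x -> features
        self.log_sigma = nn.Parameter(torch.tensor(float(sigma)).log())

    def forward(self, x, y):
        fx, fy = self.featurizer(x), self.featurizer(y)
        sq_dists = torch.cdist(fx, fy) ** 2               # pairwise ||phi(x) - phi(y)||^2
        return torch.exp(-sq_dists / (2.0 * self.log_sigma.exp() ** 2))
```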
Normal deep learning ⊂ deep kernels
- Take $\kappa$ to be the linear kernel: $k_\theta(x, y) = \langle \phi_\theta(x), \phi_\theta(y) \rangle$
- The final function in $\mathcal H_{k_\theta}$ will be $f(x) = \langle w, \phi_\theta(x) \rangle$
- With logistic loss: this is Platt scaling
So what?
- This definitely does not say that deep learning is (even approximately) a kernel method
- …despite what some people might want you to think
- We know theoretically deep learning can learn some things faster than any kernel method [see Malach+ ICML-21 + refs]
- But deep kernel learning ≠ traditional kernel models
- exactly like how usual deep learning ≠ linear models
Optimizing power of MMD tests
- Asymptotics of $\widehat{\mathrm{MMD}}^2$ give us immediately that
  $$\Pr_{H_1}(\text{reject}) \approx \Phi\!\left(\frac{\sqrt{n} \, \mathrm{MMD}^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt{n} \, \sigma_{H_1}}\right)$$
- $\mathrm{MMD}$, $\sigma_{H_1}$, $c_\alpha$ are constants: the first term usually dominates
- Pick $k$ to maximize an estimate of $J = \mathrm{MMD}^2 / \sigma_{H_1}$
- Use $\widehat{\mathrm{MMD}}^2$ from before; get $\hat\sigma_{H_1}$ from U-statistic theory (sketch below)
- Can show uniform convergence of the estimator $\hat J$
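A sketch of estimating this power criterion from Gram matrices on equal-size samples. The variance estimate follows the usual U-statistic form; the exact constants and regularization may differ slightly from the paper's version.

```python
import torch

def power_criterion(k_xx, k_yy, k_xy, eps=1e-8):
    """Differentiable estimate of J = MMD^2 / sigma_H1 for equal-size samples."""
    n = k_xx.shape[0]
    # H[i, j] = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
    h = k_xx + k_yy - k_xy - k_xy.t()
    h = h - torch.diag(torch.diag(h))                 # U-statistic: keep only i != j
    mmd2 = h.sum() / (n * (n - 1))
    row_means = h.sum(dim=1) / (n - 1)
    var = 4.0 * ((row_means - mmd2) ** 2).mean()      # estimate of sigma_H1^2
    return mmd2 / torch.sqrt(var + eps)
```

The overall recipe is then to maximize this over the deep kernel's parameters on a training split (e.g. with Adam), and run the permutation test with the learned kernel on held-out data.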
Blobs dataset
Blobs results
Investigating a GAN on MNIST
CIFAR-10 vs CIFAR-10.1
Train on 1,000, test on 1,031, repeat 10 times. Rejection rates:

| ME | SCF | C2ST | MMD-O | MMD-D |
|---|---|---|---|---|
| 0.588 | 0.171 | 0.452 | 0.316 | 0.744 |
Ablation vs classifier-based tests
| Dataset | Cross-entropy: Sign | Cross-entropy: Lin | Cross-entropy: Ours | Max power: Sign | Max power: Lin | Max power: Ours |
|---|---|---|---|---|---|---|
| Blobs | 0.84 | 0.94 | 0.90 | – | 0.95 | 0.99 |
| High-$d$ Gauss. mix. | 0.47 | 0.59 | 0.29 | – | 0.64 | 0.66 |
| Higgs | 0.26 | 0.40 | 0.35 | – | 0.30 | 0.40 |
| MNIST vs GAN | 0.65 | 0.71 | 0.80 | – | 0.94 | 1.00 |
But…
- What if you don't have much data for your testing problem?
- Need enough data to pick a good kernel
- Also need enough test data to actually detect the difference
- Best split depends on best kernel's quality / how hard to find
- Don't know that ahead of time; can't try more than one
Meta-testing
- One idea: what if we have related problems?
- Similar setup to meta-learning:
(from Wei+ 2018)
Meta-testing for CIFAR-10 vs CIFAR-10.1
- CIFAR-10 has 60,000 images, but CIFAR-10.1 only has 2,031
- Where do we get related data from?
- One option: set up tasks to distinguish classes of CIFAR-10 (airplane vs automobile, airplane vs bird, ...)
One approach (MAML-like)
The inner adaptation is, e.g., 5 steps of gradient descent
We learn the initialization, maybe the step size, etc.
This works, but not as well as we'd hoped…
The initialization might work okay on everything without really adapting
Another approach: Meta-MKL
Inspired by classic multiple kernel learning
Only need to learn a linear combination of the meta-trained kernels on the test task:
$$k_\beta(x, y) = \sum_i \beta_i \, k_i(x, y)$$
Much easier (sketch below)
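A sketch of the combination step in PyTorch: Gram matrices from kernels learned on the related meta-training tasks are mixed with learned nonnegative weights, and only those few weights are fit on the small test task. The helper name and softmax parameterization are illustrative; the actual method's parameterization may differ.

```python
import torch

def combined_gram(grams, beta):
    """Weighted combination of Gram matrices from meta-trained kernels;
    softmax keeps the weights nonnegative, so the result is a valid kernel."""
    weights = torch.softmax(beta, dim=0)
    return sum(w * g for w, g in zip(weights, grams))
```

The same `power_criterion` sketch above can then be maximized over just `beta`, which is why very little test-task data suffices.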
Theoretical analysis for Meta-MKL
- Same big-O dependence on test task size 😐
- But the multiplier is much better: based on the number of meta-training tasks, not on the network size
- Coarse analysis: assumes one meta-training task is "related" enough
- We compete with picking the single best related kernel
- Haven't analyzed meaningfully combining related kernels (yet!)
Results on CIFAR-10.1
But...
- Sometimes we know ahead of time that there are differences that we don't care about
- In the MNIST GAN criticism, initial attempt just picked out that the GAN outputs numbers that aren't one of the 256 values MNIST has
- Can we find a kernel that can distinguish $\mathbb P$ from $\mathbb Q$ (the split we care about),
  but can't distinguish $\tilde{\mathbb P}$ from $\tilde{\mathbb Q}$ (a split we want to ignore)?
- Also useful for fair representation learning
- e.g. can distinguish “creditworthy” vs not,
can't distinguish by race
High on one power, low on another
Choose $k$ with high power for the split we care about, but low power for the sensitive split
- First idea: maximize (power criterion for the target split) minus (power criterion for the sensitive split)
  - No good: doesn't balance the two powers appropriately
- Second idea: maximize (estimated test power for the target split) minus (estimated test power for the sensitive split)
  - Can estimate $c_\alpha$ inside the optimization
  - Better, but tends to "stall out" in minimizing the unwanted power
- Instead: use the previous estimator on $B$ blocks, each of size $b$
  - Final estimator: the average of each block's estimate
  - Each block has the previous asymptotics
  - Central limit theorem across blocks
  - Power is a smooth, bounded $\Phi(\cdot)$ expression we can optimize directly (rough sketch below)
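A rough sketch of the "high power on one split, low on the other" objective, using the Gaussian-CDF power approximation from earlier. The threshold stand-in `z_alpha`, the effective sample size `n_eff`, and the exact scaling are illustrative rather than the precise block-test formula.

```python
import torch
from torch.distributions import Normal

def fair_power_objective(mmd2_target, var_target, mmd2_sens, var_sens,
                         n_eff, alpha=0.05, eps=1e-8):
    """Approximate (test power on the split we want to detect) minus
    (test power on the sensitive split); both terms are bounded in [0, 1]."""
    std_normal = Normal(0.0, 1.0)
    z_alpha = std_normal.icdf(torch.tensor(1.0 - alpha))   # illustrative threshold stand-in

    def approx_power(mmd2, var):
        return std_normal.cdf(n_eff ** 0.5 * mmd2 / torch.sqrt(var + eps) - z_alpha)

    return approx_power(mmd2_target, var_target) - approx_power(mmd2_sens, var_sens)
```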
MMD-B-Fair
- Choose $k$ to maximize the block-test power for the split we want to detect while keeping it low for the sensitive split
- The objective is the power of a test with $B$ blocks, each of size $b$
  - We don't actually use a block estimator computationally
  - $B$ and $b$ have nothing to do with the minibatch size
- Representation learning:
  - The deep kernel is $k(x, y) = \kappa\big(\phi_\theta(x), \phi_\theta(y)\big)$; the learned $\phi_\theta$ is the representation
  - $\kappa$ could be deep itself, with adversarial optimization
  - For now, $\kappa$ is just Gaussians with different lengthscales
Adult
Shapes3D
Multiple targets / sensitive attributes
power on minibatches

Remaining challenges
- MMD-B-Fair:
- When the target and sensitive attributes are very correlated
- For attributes with many values (use HSIC?)
- Meta-testing: more powerful approaches, better analysis
- When $\mathbb P \ne \mathbb Q$, can we tell how they're different?
- Methods so far: low-dimensional projections, and/or points with large critic value
- Avoid the need for data splitting (selective inference)
A good takeaway
Combining a deep architecture with a kernel machine that takes the higher-level learned representation as input can be quite powerful.
— Y. Bengio & Y. LeCun (2007), “Scaling Learning Algorithms towards AI”