\definecolor{cb1}{RGB}{76,114,176} \definecolor{cb2}{RGB}{221,132,82} \definecolor{cb3}{RGB}{85,168,104} \definecolor{cb4}{RGB}{196,78,82} \definecolor{cb5}{RGB}{129,114,179} \definecolor{cb6}{RGB}{147,120,96} \definecolor{cb7}{RGB}{218,139,195} \newcommand{\abs}[1]{\left\lvert #1 \right\rvert} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator{\bigO}{\mathcal{O}} \DeclareMathOperator{\D}{\mathcal{D}} \DeclareMathOperator*{\E}{\mathbb{E}} \DeclareMathOperator*{\Var}{Var} \DeclareMathOperator{\mean}{mean} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\spn}{span} \newcommand{\tp}{^\mathsf{T}} \DeclareMathOperator{\tr}{tr} \newcommand{\ud}{\mathrm{d}} \DeclareMathOperator{\W}{\mathcal{W}} \DeclareMathOperator{\KL}{KL} \DeclareMathOperator{\JS}{JS} \DeclareMathOperator{\mmd}{MMD} \DeclareMathOperator{\smmd}{SMMD} \DeclareMathOperator{\hsic}{HSIC} \DeclareMathOperator{\mmdhat}{\widehat{MMD}} \DeclareMathOperator{\hsichat}{\widehat{HSIC}} \newcommand{\optmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{MMD}^{#1}}} \DeclareMathOperator{\optmmdhat}{\hat{\mathcal{D}}_\mathrm{MMD}} \newcommand{\optsmmdp}{\operatorname{\mathcal{D}_\mathrm{SMMD}}} \newcommand{\optsmmd}[1][\Psic]{\operatorname{\mathcal{D}_\mathrm{SMMD}^{\SS,#1,\lambda}}} \newcommand{\ktop}{k_\mathrm{top}} \newcommand{\lip}{\mathrm{Lip}} \newcommand{\cH}{\mathcal{H}} \newcommand{\h}{\mathcal{H}} \newcommand{\R}{\mathbb{R}} \newcommand{\cP}[1]{{\color{cb1} #1}} \newcommand{\PP}{\cP{\mathbb P}} \newcommand{\pp}{\cP{p}} \newcommand{\X}{\cP{X}} \renewcommand{\Xi}{\cP{X_i}} \newcommand{\Xp}{\cP{X'}} \newcommand{\x}{\cP{x}} \newcommand{\xp}{\cP{x'}} \newcommand{\nX}{\cP{n_X}}
% kexpfam colors
\newcommand{\nc}{{\color[HTML]{D62728}{n}}} \newcommand{\Xc}{{\color[HTML]{D62728}{X}}} \newcommand{\Mc}{{\color[HTML]{1F77B4}{M}}} \newcommand{\Yc}{{\color[HTML]{1F77B4}{Y}}} \newcommand{\mc}{{\color[HTML]{17BECF}{m}}} \newcommand{\dc}{{\color[HTML]{2CA02C}{d}}}
\newcommand{\Pdata}{\cP{\mathbb{P}_\mathrm{data}}} \newcommand{\cQ}[1]{{\color{cb2} #1}} \newcommand{\QQ}{\cQ{\mathbb Q}} \newcommand{\qq}{\cQ{q}} \newcommand{\qtheta}{\cQ{q_\theta}} \newcommand{\Y}{\cQ{Y}} \newcommand{\Yp}{\cQ{Y'}} \newcommand{\Yj}{\cQ{Y_j}} \newcommand{\nY}{\cQ{n_Y}} \newcommand{\y}{\cQ{y}} \newcommand{\yp}{\cQ{y'}} \newcommand{\thetac}{\cQ{\theta}} \newcommand{\vtheta}{\thetac} \newcommand{\Qtheta}{\QQ_\thetac} \newcommand{\Gtheta}{\cQ{G_\theta}} \newcommand{\cZ}[1]{{\color{cb5} #1}} \newcommand{\Z}{\cZ{Z}} \newcommand{\Zc}{\cZ{\mathcal Z}} \newcommand{\ZZ}{\cZ{\mathbb Z}} \newcommand{\cpsi}[1]{{\color{cb3} #1}} \newcommand{\psic}{\cpsi{\psi}} \newcommand{\Psic}{\cpsi{\Psi}} \newcommand{\Dpsi}{\cpsi{D_\psi}} \renewcommand{\SS}{\cpsi{\mathbb{S}}} \newcommand{\Xtilde}{\cpsi{\tilde{X}}} \newcommand{\Xtildep}{\cpsi{\tilde{X}'}}

Are these datasets the same?
Efficient, fair two-sample testing
with learned kernels

Danica J. Sutherland(she/they)
University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
Hsiao-Yu (Fish) Tung
Heiko Strathmann
Soumyajit De
Aaditya Ramdas
Alex Smola
Arthur Gretton
Feng Liu
Wenkai Xu
Jie Lu
Guangquan Zhang
new!
Namrata Deka

Toronto Womxn in Data Science - April 27, 2022


Data drift

  • Textbook machine learning:
    • Train on i.i.d. samples from some distribution, \Xi \sim \PP
    • If it works on \X , probably works on new samples from \PP
  • Really:
    • Train on \X
    • Pretend there's a \PP that \X is an i.i.d. sample from
    • If it works on \X , maybe it sorta works on \PP
    • Deploy on something that's maybe a distribution \QQ
      • which might be sort of like \PP
      • but probably changes over time…

This talk

Based on samples \{ \Xi \} \sim \PP and \{ \Yj \} \sim \QQ :

  • How is \PP different from \QQ ?
  • Is \PP close enough to \QQ for our model?
  • Is \PP = \QQ ?

Two-sample testing

  • Given samples from two unknown distributions \X \sim \PP \qquad \Y \sim \QQ
  • Question: is \PP = \QQ ?
  • Hypothesis testing approach: H_0: \PP = \QQ \qquad H_1: \PP \ne \QQ
  • Reject null hypothesis H_0 if test statistic \hat T(\X, \Y) > c_\alpha
  • Do smokers/non-smokers get different cancers?
  • Do Brits have the same friend network types as Americans?
  • When does my laser agree with the one on Mars?
  • Are storms in the 2000s different from storms in the 1800s?
  • Does presence of this protein affect DNA binding? [MMDiff2]
  • Do these dob and birthday columns mean the same thing?
  • Does my generative model \Qtheta match \Pdata ?
  • Independence testing: is P(\X, \Y) = P(\X) P(\Y) ?

What's a hypothesis test again?


Permutation testing to find c_\alpha

Need \Pr_{H_0}\left( \hat T(\X, \Y) > c_\alpha \right) \le \alpha

(figure: the pooled samples X_1, \dots, X_5, Y_1, \dots, Y_5 are repeatedly shuffled and re-split into permuted samples \cP{\tilde X_i} and \cQ{\tilde Y_i})

c_\alpha : the (1-\alpha) quantile of \left\{ \fragment[1]{\hat T(\cP{\tilde X_1}, \cQ{\tilde Y_1}), \;} \fragment[2]{\hat T(\cP{\tilde X_2}, \cQ{\tilde Y_2}), \;} \fragment[3]{\cdots} \right\}
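A minimal numpy sketch of this permutation procedure, assuming equal sample sizes and using a (biased) Gaussian-kernel MMD statistic as \hat T — the statistic, function names, and bandwidth are illustrative choices, not the paper's:

```python
import numpy as np

def mmd2_biased(X, Y, sigma=1.0):
    """Biased MMD^2 estimate with a Gaussian RBF kernel (illustrative choice)."""
    Z = np.concatenate([X, Y])
    sq = np.sum(Z**2, axis=1)
    K = np.exp(-(sq[:, None] + sq[None, :] - 2 * Z @ Z.T) / (2 * sigma**2))
    n = len(X)
    K_XX, K_YY, K_XY = K[:n, :n], K[n:, n:], K[:n, n:]
    return K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()

def permutation_threshold(X, Y, stat=mmd2_biased, n_perms=200, alpha=0.05, seed=0):
    """c_alpha: the (1 - alpha) quantile of the statistic under random re-splits."""
    rng = np.random.default_rng(seed)
    Z = np.concatenate([X, Y])
    n = len(X)
    stats = []
    for _ in range(n_perms):
        perm = rng.permutation(len(Z))  # shuffle the pooled sample
        stats.append(stat(Z[perm[:n]], Z[perm[n:]]))
    return np.quantile(stats, 1 - alpha)
```

Reject H_0 when the statistic on the original split exceeds the returned threshold; under H_0 the re-splits are exchangeable with the original, which gives the level-\alpha guarantee.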

Classifier two-sample tests

  • We need a \hat T(\X, \Y) that's large if \PP \ne \QQ , small if \PP = \QQ
  • Can choose \hat T(\X, \Y) as the accuracy of f on the test set
    • If \PP = \QQ , classification impossible: \hat T \sim \mathrm{Binom}(n, \frac12)
  • Usually better: \displaystyle \hat T(\X, \Y) = \operatorname*{mean}_{\cP{x} \in \cP{X_\mathit{test}}}[ f(\x) ] - \operatorname*{mean}_{\y \in \cQ{Y_\mathit{test}}}[ f(\y) ]
    • If \PP = \QQ , \hat T \to normal (permuting on test set is better)

A more general framework

  • C2ST-L: \displaystyle \hat T(\X, \Y) = \operatorname*{mean}_{\cP{x} \in \cP{X_\mathit{test}}}[ f(\cP{x}) ] - \operatorname*{mean}_{\cQ{y} \in \cQ{Y_\mathit{test}}}[ f(\cQ{y}) ]
    • f(x) \in \R is a classifier's logit:
      log probability \cP{x} is from \PP rather than \QQ , plus const
  • Basically same: \displaystyle \hat T(\X, \Y) = \abs{ \operatorname*{mean}_{\cP{x} \in \cP{X_\mathit{test}}}[ f(\cP{x}) ] - \operatorname*{mean}_{\cQ{y} \in \cQ{Y_\mathit{test}}}[ f(\cQ{y}) ] }
  • What if we use more general features of the data? \hat T(\X, \Y) = \norm{ \operatorname*{mean}_{\cP{x} \in \cP{X_\mathit{test}}}[ \varphi(\cP{x}) ] - \operatorname*{mean}_{\cQ{y} \in \cQ{Y_\mathit{test}}}[ \varphi(\cQ{y}) ] }

Difference between mean embeddings

\begin{align*} \hat T&(\X, \Y)^2 = \norm{ \operatorname*{mean}_{\cP{x} \in \cP{X_\mathit{test}}}[ \varphi(\cP{x}) ] - \operatorname*{mean}_{\cQ{y} \in \cQ{Y_\mathit{test}}}[ \varphi(\cQ{y}) ] }^2 \\&\fragment[0]{ = \fragment[2][highlight-current-red]{\operatorname*{mean}_{\cP{x} \ne \cP{x'}}[ \varphi(\cP{x}) \cdot \varphi(\cP{x'}) ] } - 2 \fragment[3][highlight-current-red]{\operatorname*{mean}_{\cP{x}, \cQ{y}}[ \varphi(\cP{x}) \cdot \varphi(\cQ{y}) ] } + \fragment[4][highlight-current-red]{\operatorname*{mean}_{\cQ{y} \ne \cQ{y'}}[ \varphi(\cQ{y}) \cdot \varphi(\cQ{y'}) ] } } \end{align*}

Only use data through \varphi(\x) \cdot \varphi(\y) = k(\x, \y) : can kernelize!

K_{\X\X} = \begin{pmatrix} 1.0 & 0.2 & 0.6 \\ 0.2 & 1.0 & 0.5 \\ 0.6 & 0.5 & 1.0 \end{pmatrix}
\qquad
K_{\Y\Y} = \begin{pmatrix} 1.0 & 0.8 & 0.7 \\ 0.8 & 1.0 & 0.6 \\ 0.7 & 0.6 & 1.0 \end{pmatrix}
\qquad
K_{\X\Y} = \begin{pmatrix} 0.3 & 0.1 & 0.2 \\ 0.2 & 0.3 & 0.3 \\ 0.2 & 0.1 & 0.4 \end{pmatrix}
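Plugging the toy kernel matrices above into the estimator from the previous slide — averaging K_{\X\X} and K_{\Y\Y} off the diagonal (x \ne x', y \ne y') and subtracting twice the mean of K_{\X\Y} — gives a concrete number; a minimal numpy sketch:

```python
import numpy as np

# The illustrative 3x3 kernel matrices from the slide.
K_XX = np.array([[1.0, 0.2, 0.6], [0.2, 1.0, 0.5], [0.6, 0.5, 1.0]])
K_YY = np.array([[1.0, 0.8, 0.7], [0.8, 1.0, 0.6], [0.7, 0.6, 1.0]])
K_XY = np.array([[0.3, 0.1, 0.2], [0.2, 0.3, 0.3], [0.2, 0.1, 0.4]])

def mmd2_unbiased(K_XX, K_YY, K_XY):
    """Mean over x != x' and y != y' (drop diagonals), minus twice the cross mean."""
    n, m = len(K_XX), len(K_YY)
    xx = (K_XX.sum() - np.trace(K_XX)) / (n * (n - 1))
    yy = (K_YY.sum() - np.trace(K_YY)) / (m * (m - 1))
    return xx + yy - 2 * K_XY.mean()

print(round(mmd2_unbiased(K_XX, K_YY, K_XY), 3))  # 0.667
```

The within-sample similarities (0.433 and 0.7) are much larger than the cross-sample mean (0.233), so the estimate is far from zero — this kernel separates the two samples well.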

What's a kernel again?

  • Linear classifiers: \hat y(x) = \sign(f(x)) , f(x) = w\tp \left( x, 1 \right)
  • Use a “richer” x : f(x) = w\tp \left( x, x^2, 1\right) = w\tp \varphi(x)
  • Can avoid explicit \varphi(x) ; instead k(x, y) = \langle \varphi(x), \varphi(y) \rangle_{\cH}
  • “Kernelized” algorithms access data only through k(x, y) f(x) = \langle w, \varphi(x) \rangle_\cH = \sum_{i=1}^n \alpha_i k(X_i, x)
  • \lVert f \rVert_\cH = \sqrt{\alpha\tp K \alpha} gives kernel notion of smoothness

Reproducing Kernel Hilbert Space (RKHS)

  • Ex: Gaussian RBF / exponentiated quadratic / squared exponential / … k(x, y) = \exp\left( - \frac{\norm{x - y}^2}{2 \sigma^2} \right)
  • Some functions with small \lVert f \rVert_\cH :

Maximum Mean Discrepancy (MMD)

\mmd_k(\PP, \QQ) = \sup_{\lVert f \rVert_\cH \le 1} \E_{\X \sim \PP}[f(\X)] - \E_{\Y \sim \QQ}[f(\Y)]

The sup is achieved by f(t) \propto \mathbb E_{\x \sim \PP}[k(\x, t)] - \mathbb E_{\y \sim \QQ}[k(\y, t)]

\mmd^2(\PP, \QQ) = \E_{\substack{\x, \xp \sim \PP\\\y, \yp \sim \QQ}}\left[ k(\x, \xp) + k(\y, \yp) - 2 k(\x, \y) \right]


MMD-based tests

  • If k is characteristic, \mmd(\PP, \QQ) = 0 iff \PP = \QQ
  • Efficient permutation testing for \mmdhat(\X, \Y)
    • H_0 : n \mmdhat^2 converges in distribution
    • H_1 : \sqrt n ( \mmdhat^2 - \mmd^2 ) asymptotically normal
  • Any characteristic kernel gives consistent test…eventually
  • Need enormous n if kernel is bad for problem

Deep learning and deep kernels

  • C2ST-L is basically MMD with k(x, y) = f(x) f(y)
    • f is a (learned) deep net – a learned kernel
  • We can generalize some more to deep kernels: k_\psic(x, y) = \kappa\left( \phi_\psic(x), \phi_\psic(y) \right)
    • \phi is a deep net, maps data points to \R^D
    • \kappa is a simple kernel on \R^D
    • \kappa(u, v) = u \cdot v gives MMD as \norm{\E \phi(\cP x) - \E \phi(\cQ y)}
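A minimal sketch of such a deep kernel, with a toy two-layer numpy "network" standing in for \phi (here with random, untrained weights — in practice \phi is learned) and a Gaussian \kappa on the feature space; all names and shapes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(16, 2)), rng.normal(size=(4, 16))  # untrained toy weights

def phi(x):
    """Toy feature map: a two-layer ReLU net mapping R^2 -> R^D (here D = 4)."""
    return np.maximum(W2 @ np.maximum(W1 @ x, 0.0), 0.0)

def kappa(u, v, sigma=1.0):
    """Simple Gaussian kernel on the feature space R^D."""
    return np.exp(-np.sum((u - v) ** 2) / (2 * sigma**2))

def deep_kernel(x, y):
    """k_psi(x, y) = kappa(phi(x), phi(y))."""
    return kappa(phi(x), phi(y))
```

Swapping in kappa(u, v) = u @ v instead would make the resulting MMD exactly the distance between the mean feature vectors, recovering the C2ST-style statistic.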

Optimizing power of MMD tests

  • Asymptotics of \mmdhat^2 give us immediately that \Pr_{H_1}\left( n \mmdhat^2 > c_\alpha \right) \approx \Phi\left( \frac{\sqrt n \mmd^2}{\sigma_{H_1}} - \frac{c_\alpha}{\sqrt n \sigma_{H_1}} \right)
    • \mmd , \sigma_{H_1} , c_\alpha are constants: the first term usually dominates
  • Pick k to maximize an estimate of \mmd^2 / \sigma_{H_1}
  • Use \mmdhat from before, get \hat\sigma_{H_1} from U-statistic theory
  • Can show uniform \mathcal O_P(n^{-\frac13}) convergence of estimator
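A sketch of estimating the criterion \mmdhat^2 / \hat\sigma_{H_1} from kernel matrices (equal sample sizes assumed). The variance term follows the standard U-statistic approximation; the small regularizer lam is my own stabilizer and not necessarily the exact form used in the paper:

```python
import numpy as np

def power_criterion(K_XX, K_YY, K_XY, lam=1e-8):
    """Estimate MMD^2 / sigma_{H1} from kernel matrices (equal sample sizes)."""
    n = len(K_XX)
    H = K_XX + K_YY - K_XY - K_XY.T       # H[i, j] = h(z_i, z_j), the U-stat kernel
    np.fill_diagonal(H, 0.0)
    mmd2 = H.sum() / (n * (n - 1))        # unbiased MMD^2 estimate
    # Standard U-statistic variance approximation: 4 * Var_i[ mean_j h(z_i, z_j) ]
    row_means = H.sum(axis=1) / (n - 1)
    var = 4.0 * (np.mean(row_means**2) - np.mean(row_means) ** 2)
    return mmd2 / np.sqrt(max(var, 0.0) + lam)
```

Maximizing this quantity over kernel parameters (e.g. by gradient ascent through the kernel matrices) is what "pick k to maximize an estimate of \mmd^2 / \sigma_{H_1}" means in practice.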

Blobs dataset

Blobs kernels

Blobs results

Investigating a GAN on MNIST

CIFAR-10 vs CIFAR-10.1

Train on 1 000, test on 1 031, repeat 10 times. Rejection rates:
  Method:          ME     SCF    C2ST   MMD-O  MMD-D
  Rejection rate:  0.588  0.171  0.452  0.316  0.744

Ablation vs classifier-based tests

                         Cross-entropy    Max power
  Dataset                Sign    Lin      Sign    Lin     Ours
  Blobs                  0.84    0.94     0.90    0.95    0.99
  High-d Gauss. mix.     0.47    0.59     0.29    0.64    0.66
  Higgs                  0.26    0.40     0.35    0.30    0.40
  MNIST vs GAN           0.65    0.71     0.80    0.94    1.00

But…

  • What if you don't have much data for your testing problem?
  • Need enough data to pick a good kernel
  • Also need enough test data to actually detect the difference
  • Best split depends on best kernel's quality / how hard to find
    • Don't know that ahead of time; can't try more than one

Meta-testing

  • One idea: what if we have related problems?
  • Similar setup to meta-learning (figure from Wei+ 2018)

Meta-testing for CIFAR-10 vs CIFAR-10.1

  • CIFAR-10 has 60,000 images, but CIFAR-10.1 only has 2,031
  • Where do we get related data from?
  • One option: set up tasks to distinguish classes of CIFAR-10 (airplane vs automobile, airplane vs bird, ...)

One approach (MAML-like)

A_\theta is, e.g., 5 steps of gradient descent

we learn the initialization, maybe step size, etc

This works, but not as well as we'd hoped…
The initialization might work okay on everything without really adapting to each task

Another approach: Meta-MKL

Inspired by classic multiple kernel learning

Only need to learn linear combination \beta_i on test task:
much easier
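The Meta-MKL idea — kernels k_i are fixed after meta-training, and only the mixing weights \beta_i are fit on the (small) test task — can be sketched as below. The crude random search maximizing \widehat{MMD}^2 is a stand-in for the actual power-criterion optimization; names and procedure are illustrative, not the paper's:

```python
import numpy as np

def combined_stat(betas, Ks_XX, Ks_YY, Ks_XY):
    """Unbiased MMD^2 for k_beta = sum_i beta_i k_i (kernels combine linearly)."""
    K_XX = sum(b * K for b, K in zip(betas, Ks_XX))
    K_YY = sum(b * K for b, K in zip(betas, Ks_YY))
    K_XY = sum(b * K for b, K in zip(betas, Ks_XY))
    n, m = len(K_XX), len(K_YY)
    xx = (K_XX.sum() - np.trace(K_XX)) / (n * (n - 1))
    yy = (K_YY.sum() - np.trace(K_YY)) / (m * (m - 1))
    return xx + yy - 2 * K_XY.mean()

def search_betas(Ks_XX, Ks_YY, Ks_XY, n_tries=500, seed=0):
    """Random search over the simplex; a crude stand-in for a proper criterion fit."""
    rng = np.random.default_rng(seed)
    best, best_val = None, -np.inf
    for _ in range(n_tries):
        b = rng.dirichlet(np.ones(len(Ks_XX)))   # weights >= 0, summing to 1
        val = combined_stat(b, Ks_XX, Ks_YY, Ks_XY)
        if val > best_val:
            best, best_val = b, val
    return best
```

Because only \lvert \beta \rvert numbers are fit on the test task — rather than a whole network — far less test-task data is needed.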

Theoretical analysis for Meta-MKL

  • Same big-O dependence on test task size 😐
  • But multiplier is much better:
    based on number of meta-training tasks, not on network size
  • Coarse analysis: assumes one meta-training task is “related” enough
    • We compete with picking the single best related kernel
    • Haven't analyzed meaningfully combining related kernels (yet!)

Results on CIFAR-10.1

But...

  • Sometimes we know ahead of time that there are differences that we don't care about
    • In the MNIST GAN criticism, the initial attempt just picked up that the GAN outputs pixel values that aren't among the 256 values MNIST images can take
  • Can we find a kernel that can distinguish \PP^t from \QQ^t ,
    but can't distinguish \PP^s from \QQ^s ?
  • Also useful for fair representation learning
    • e.g. can distinguish “creditworthy” vs not,
      can't distinguish by race

High on one power, low on another

Choose k with \min_k \rho_k^s - \rho_k^t

  • First idea: \displaystyle \rho = \frac{(\mmd)^2}{\sigma_{H_1}}
    • No good: doesn't balance power appropriately
  • Second idea: \displaystyle \rho = \Phi\left( \frac{\sqrt{n} (\mmd)^2 - c_\alpha}{\sigma_{H_1}} \right)
    • Can estimate c_\alpha inside the optimization
    • Better, but tends to “stall out” in minimizing \rho_k^s

Block estimator [Zaremba+ NeurIPS-13]

  • Use previous \mmdhat on b blocks, each of size B
  • Final estimator: average of each block's estimate
    • Each block has previous asymptotics
    • Central limit theorem across blocks
  • Power is \displaystyle \rho = \Phi\left( \sqrt{b B} \frac{\mmd^2}{\sigma_{H_1}^2} - \Phi^{-1}(1 - \alpha) \right)
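This power formula can be evaluated directly with the standard normal CDF from Python's standard library; the numbers below are made up purely for illustration:

```python
from statistics import NormalDist

def block_test_power(mmd2, sigma2, b, B, alpha=0.05):
    """rho = Phi( sqrt(b * B) * MMD^2 / sigma^2 - Phi^{-1}(1 - alpha) )"""
    Phi = NormalDist()  # standard normal: cdf and inverse cdf
    z = (b * B) ** 0.5 * mmd2 / sigma2 - Phi.inv_cdf(1 - alpha)
    return Phi.cdf(z)

# Illustrative numbers only: power grows as blocks get more numerous or bigger.
print(block_test_power(mmd2=0.05, sigma2=1.0, b=16, B=64))
```

Unlike the earlier \Phi-based objective, this expression is smooth in both the "signal" term and the threshold, which is what makes it a usable training objective.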

MMD-B-Fair

  • Choose k as \min_k \rho_k^s - \rho_k^t
    • \rho is the power of a test with b blocks of size B
    • We don't actually use a block estimator computationally
    • b , B have nothing to do with minibatch size
  • Representation learning: \min_\phi \max_\kappa \rho^s_{\kappa \circ \phi} - \rho^t_{\kappa \circ \phi}
    • Deep kernel is [\kappa \circ \phi](x, y) = \kappa(\phi(x), \phi(y))
    • \kappa could be deep itself, with adversarial optimization
    • For now, just Gaussians with different lengthscales

Adult

Shapes3D

(example images: the \PP^t vs \QQ^t pair and the \PP^s vs \QQ^s pair)

Multiple targets / sensitive attributes

\displaystyle \max_k \frac{1}{\lvert \mathcal T \rvert} \sum_{t \in \mathcal T} \rho^t_k - \frac{1}{\lvert \mathcal S \rvert} \sum_{s \in \mathcal S} \rho^s_k

(the \rho terms are estimated as test power on minibatches)

Remaining challenges

  • MMD-B-Fair fails when s and t are very correlated
  • Meta-testing: more powerful approaches, better analysis
  • When \PP \ne \QQ , can we tell how they're different?
    • Methods so far: low-d, and/or points w/ large critic value
  • Avoid the need for data splitting (selective inference)
  • Online detection of data shifts

A good takeaway

Combining a deep architecture with a kernel machine that takes the higher-level learned representation as input can be quite powerful.

— Y. Bengio & Y. LeCun (2007), “Scaling Learning Algorithms towards AI”