\definecolor{cb1}{RGB}{76,114,176}
\definecolor{cb2}{RGB}{221,132,82}
\definecolor{cb3}{RGB}{85,168,104}
\definecolor{cb4}{RGB}{196,78,82}
\definecolor{cb5}{RGB}{129,114,179}
\definecolor{cb6}{RGB}{147,120,96}
\definecolor{cb7}{RGB}{218,139,195}
\newcommand{\abs}[1]{\left\lvert #1 \right\rvert}
\newcommand{\norm}[1]{\left\lVert #1 \right\rVert}
\DeclareMathOperator*{\argmin}{argmin}
\DeclareMathOperator{\bigO}{\mathcal{O}}
\DeclareMathOperator{\Cov}{Cov}
\DeclareMathOperator{\E}{\mathbb{E}}
\newcommand{\indic}{\mathbb{I}}
\newcommand{\R}{\mathbb{R}}
\newcommand{\tp}{^\mathsf{T}}
\DeclareMathOperator{\Tr}{Tr}
\DeclareMathOperator{\Var}{Var}
\DeclareMathOperator{\NTK}{NTK}
\DeclareMathOperator{\eNTK}{eNTK}
\DeclareMathOperator{\pNTK}{pNTK}
\newcommand{\lin}{\mathit{lin}}
\newcommand{\xt}{{\color{cb2} \tilde x}}
\newcommand{\exi}{{\color{cb1} x_i}}
\newcommand{\yi}{{\color{cb1} y_i}}
\newcommand{\Xs}{{\color{cb1} \mathbf X}}
\newcommand{\ys}{{\color{cb1} \mathbf y}}
\newcommand{\w}{\mathbf{w}}
\newcommand{\tar}{\mathit{tar}}
\newcommand{\ptari}{{\color{cb1} p^\tar_i}}
\newcommand{\pstari}{{\color{cb1} p^*_i}}
A Defense of (Empirical) Neural Tangent Kernels

Danica J. Sutherland (she)
University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
Yi Ren
Shangmin Guo
Wonho Bae
Active learning [NeurIPS-22]
Mohamad Amin Mohamadi
University of Michigan AI Seminar - March 23, 2023
Slides available at djsutherland.ml/slides/entk-umich
(A lot of) this talk in a tweet:
One path to NTKs (Taylor's version)

“Learning path” of a model's predictions: $f_t(\xt)$ for some fixed $\xt$ as the parameters $\w_t$ change.

Let's start with “plain” SGD on $\frac1N \sum_{i=1}^N \ell_{\yi}(f(\exi))$:
\begin{align*}
\!\!\!\!\!\!\!\!\!\!\!
\underbrace{f_{t+1}(\xt) - f_t(\xt)}_{k \times 1}
&= \underbrace{\left( \nabla_\w f(\xt) \rvert_{\w_t} \right)}_{k \times p}
\,
\underbrace{(\w_{t+1} - \w_t)}_{p \times 1}
{\color{gray}{} + \bigO\left( \norm{\w_{t+1} - \w_t}^2 \right)}
\\&\fragment[3]{ {}
= \left( \nabla_\w f(\xt) \rvert_{\w_t} \right)
\left( - \eta \nabla_\w \ell_\yi(f(\exi)) \rvert_{\w_t} \right)\tp
{\color{gray}{} + \bigO(\eta^2)}
}
\\&\fragment[4]{ {}
= - \eta \,
\underbrace{ \left( \nabla_\w f(\xt) \rvert_{\w_t} \right) }_{k \times p} \,
\underbrace{ \left( \nabla_\w f(\exi) \rvert_{\w_t} \right)\tp }_{p \times k} \,
\underbrace{ \ell_\yi'(f_t(\exi)) }_{k \times 1}
{\color{gray}{} + \bigO(\eta^2)}
}
\\&\fragment[5]{ {}
= - \eta \eNTK_{\w_t}(\xt, \exi) \, \ell'_\yi(f_t(\exi))
{\color{gray}{} + \bigO(\eta^2)}
}
\end{align*}
Here we defined
$\eNTK_\w(\xt, \exi) = \left( \nabla_\w f(\xt; \w) \right) \, \left( \nabla_\w f(\exi; \w) \right)\tp \in \R^{k \times k}$.

A step on $(\exi, \yi)$ barely changes the prediction at $\xt$ if $\eNTK_{\w_t}(\xt, \exi)$ is small.
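A minimal NumPy sketch of this first-order picture, under stated assumptions: a hypothetical toy network $f(x) = V \tanh(W x)$ with $k = 2$ outputs and square loss $\frac12 \norm{\hat y - y}^2$ (so $\ell'_y(\hat y) = \hat y - y$). The helpers \texttt{jacobian} and \texttt{entk} are illustrative names, not from any library. It takes one SGD step on $(\exi, \yi)$ and compares the true change in $f(\xt)$ with $-\eta \, \eNTK_{\w_t}(\xt, \exi) \, \ell'_\yi(f_t(\exi))$; the two should differ only at order $\eta^2$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, h, k = 3, 64, 2                       # toy input dim, hidden width, number of outputs
W = rng.normal(size=(h, d)) / np.sqrt(d)
V = rng.normal(size=(k, h)) / np.sqrt(h)

def f(x, W, V):
    """Hypothetical toy network f(x) = V tanh(W x) with k outputs."""
    return V @ np.tanh(W @ x)

def jacobian(x, W, V):
    """k x p Jacobian of f(x) w.r.t. all parameters, flattened as [W, V]."""
    a = np.tanh(W @ x)                                   # hidden activations, shape (h,)
    dW = (V * (1 - a**2))[:, :, None] * x                # df_j/dW_mn = V_jm (1 - a_m^2) x_n
    dV = np.einsum("jl,m->jlm", np.eye(len(V)), a)       # df_j/dV_lm = 1(j = l) a_m
    return np.concatenate([dW.reshape(len(V), -1), dV.reshape(len(V), -1)], axis=1)

def entk(x1, x2, W, V):
    """eNTK_w(x1, x2) = J(x1) J(x2)^T, a k x k matrix."""
    return jacobian(x1, W, V) @ jacobian(x2, W, V).T

# One SGD step on (x_i, y_i) with loss 0.5 ||f(x_i) - y_i||^2, so l' = f(x_i) - y_i.
x_i, y_i, x_t = rng.normal(size=d), np.array([1.0, -1.0]), rng.normal(size=d)
eta = 1e-4
lprime = f(x_i, W, V) - y_i
grad = jacobian(x_i, W, V).T @ lprime                    # gradient w.r.t. flattened params
W_new = W - eta * grad[: h * d].reshape(h, d)
V_new = V - eta * grad[h * d :].reshape(k, h)

actual = f(x_t, W_new, V_new) - f(x_t, W, V)
predicted = -eta * entk(x_t, x_i, W, V) @ lprime         # first-order (eNTK) prediction
print(actual, predicted)                                 # should agree up to O(eta^2)
```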
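Here $\ell_y'(\hat y) = \hat y - y$ for square loss $\ell_y(\hat y) = \frac12 \norm{\hat y - y}^2$, and $\ell_y'(\hat y) = \operatorname{softmax}(\hat y) - e_y$ for cross-entropy $\ell_y(\hat y) = -\hat y_y + \log \sum_{j=1}^k \exp(\hat y_j)$.

Full-batch GD, “stacking things up”: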
\begin{align*}
f_{t+1}(\xt) - f_t(\xt)
&= - \frac{\eta}{N} \sum_{i=1}^N \eNTK_{\w_t}(\xt, \exi) \ell'_{\yi}(f_t(\exi))
{\color{gray}{} + \bigO(\eta^2)}
\\&\fragment[1]{ {}
= - \frac{\eta}{N}
\underbrace{\eNTK_{\w_t}(\xt, \Xs)}_{k \times k N}
\, \underbrace{L_\ys'(f_t(\Xs))}_{k N \times 1}
{\color{gray}{} + \bigO(\eta^2)}
}
\end{align*}
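For a model that is linear in its parameters the $\bigO(\eta^2)$ term vanishes, which makes the stacked form easy to check. This sketch assumes a hypothetical linear model $f(x) = W x$ with mean square loss $\frac1N \sum_i \frac12 \norm{W \exi - \yi}^2$, for which $\eNTK(x_1, x_2) = (x_1 \cdot x_2) I_k$; outputs are stacked sample-major into length-$kN$ vectors.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, k, eta = 5, 4, 3, 0.1
X = rng.normal(size=(N, d))      # training inputs, one per row
Y = rng.normal(size=(N, k))      # square-loss targets
W = rng.normal(size=(k, d))      # hypothetical linear model f(x) = W x
x_t = rng.normal(size=d)         # the fixed point x~ whose prediction we track

# For f(x) = W x, eNTK(x1, x2) = (x1 . x2) I_k, so eNTK(x_t, X) is a k x kN block row.
entk_t_X = np.kron((X @ x_t)[None, :], np.eye(k))
R = X @ W.T - Y                  # l'_{y_i}(f(x_i)) = f(x_i) - y_i, one row per example
stacked = R.reshape(-1)          # L'_y(f(X)) as a length-kN vector, sample-major

# One full-batch GD step on (1/N) sum_i 0.5 ||W x_i - y_i||^2.
W_new = W - eta * R.T @ X / N
actual = W_new @ x_t - W @ x_t
predicted = -eta / N * entk_t_X @ stacked
```

Because $f$ is linear in $W$ here, \texttt{actual} and \texttt{predicted} agree up to floating-point error; for a real network they would agree only to first order in $\eta$.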
Observation I: If $f$ is “wide enough” with any usual architecture+init* [Yang+Littwin 2021], then $\eNTK(\cdot, \Xs)$ is roughly constant through training.

For square loss, $L'_{\ys}(f_t(\Xs)) = f_t(\Xs) - \ys$: the dynamics agree with kernel regression!
\begin{align*}
f_t(\xt) \xrightarrow{t \to \infty} \eNTK_{\w_0}(\xt, \Xs) \eNTK_{\w_0}(\Xs, \Xs)^{-1} (\ys - f_0(\Xs)) + f_0(\xt)
\end{align*}
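To see the kernel-regression limit concretely, here is a sketch for a hypothetical linear model $f(x) = W x$, whose eNTK is exactly constant, $(x_1 \cdot x_2) I_k$. It assumes $N < d$ so that $\eNTK(\Xs, \Xs)$ is invertible, runs full-batch GD on square loss for many steps with a step size set from the largest eigenvalue, and compares the limit with the formula above.

```python
import numpy as np

rng = np.random.default_rng(2)
N, d, k = 3, 6, 2                 # fewer points than dims, so K(X, X) is invertible
X, Y = rng.normal(size=(N, d)), rng.normal(size=(N, k))
W0 = rng.normal(size=(k, d))      # initialization of the hypothetical model f(x) = W x
x_t = rng.normal(size=d)

# Full-batch GD on (1/N) sum_i 0.5 ||W x_i - y_i||^2; the eNTK never changes here.
eta = 1.0 / np.linalg.eigvalsh(X @ X.T / N).max()
W = W0.copy()
for _ in range(20_000):
    W -= eta * (X @ W.T - Y).T @ X / N
gd_pred = W @ x_t

# Kernel regression with the scalar kernel K(a, b) = a . b (so eNTK = K I_k).
K_XX, K_tX = X @ X.T, X @ x_t
kernel_pred = W0 @ x_t + (Y - X @ W0.T).T @ np.linalg.solve(K_XX, K_tX)
print(gd_pred, kernel_pred)
```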
Observation II: As $f$ becomes “infinitely wide” with any usual architecture+init* [Yang 2019], $\eNTK_{\w_0}(x_1, x_2) \xrightarrow{a.s.} \NTK(x_1, x_2)$, independent of the random $\w_0$.
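A quick sketch of the concentration behind Observation II, assuming a one-hidden-layer ReLU network $f(x) = \frac{1}{\sqrt h} v \cdot \operatorname{relu}(W x)$ with standard Gaussian initialization (NTK parametrization) and a single output; this setup is a hypothetical illustration. It computes $\eNTK_{\w_0}(x_1, x_2)$ in closed form for many random initializations; the spread across initializations should shrink as the width $h$ grows, roughly like $1/\sqrt h$.

```python
import numpy as np

def entk_scalar(x1, x2, h, rng):
    """eNTK of f(x) = v . relu(W x) / sqrt(h) at a fresh random initialization."""
    W = rng.normal(size=(h, x1.size))
    v = rng.normal(size=h)
    a1, a2 = W @ x1, W @ x2
    from_v = np.maximum(a1, 0) @ np.maximum(a2, 0)               # gradients w.r.t. v
    from_W = (v**2 * (a1 > 0) * (a2 > 0)).sum() * (x1 @ x2)    # gradients w.r.t. W
    return (from_v + from_W) / h

rng = np.random.default_rng(3)
x1, x2 = np.array([1.0, 0.5, -0.3]), np.array([0.2, 1.0, 0.4])
spread = {}
for h in [10, 100, 10_000]:
    samples = [entk_scalar(x1, x2, h, rng) for _ in range(200)]
    spread[h] = np.std(samples)          # variation across random initializations
    print(h, np.mean(samples), spread[h])
```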
Infinite NTKs are great

Infinitely-wide neural networks have very simple behaviour! No need to worry about bad local minima, optimization complications, …

Understanding the “implicit bias” of wide nets $\approx$ understanding the NTK norm of functions.

We can compute $\NTK$ exactly for many architectures.

A great kernel for many kernel methods!

But (infinite) NTKs aren't “the answer”

Computational expense:
Poor scaling for large-data problems: typically $n^2$ memory and $n^2$ to $n^3$ computation.
CIFAR-10 has $n = 50\,000$, $k = 10$: an $nk \times nk$ matrix of float64s is 2 terabytes!
ILSVRC2012 has $n \approx 1\,200\,000$, $k = 1\,000$: 11.5 million terabytes (11.5 exabytes).
For deep/complex models (especially CNNs), each pair is very slow / memory-intensive to compute.

Practical performance:
Typically performs worse than GD for “non-small-data” tasks (MNIST and up).

Theoretical limitations:
The NTK “doesn't do feature learning”: we now know many problems where gradient descent on an NN $\gg$ any kernel method.
There are cases where GD error $\to 0$, but any kernel is barely better than random [Malach+ 2021].

What can we learn from empirical NTKs? In this talk:
As a theoretical-ish tool for local understanding:
Fine-grained explanation for early stopping in knowledge distillation
How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees

Better supervisory signal implies better learning

Classification: the target is $L_P(f) = \E_{(x, y)} \ell_y(f(x)) = \E_x \E_{y \mid x} \ell_y(f(x))$.

Normally: see $\{(\exi, \yi)\}$ and minimize the following, where $\vec\ell(\hat y) \in \R^k$ is the vector of losses for all possible labels:
\begin{align*}
\!\!\!
L(f)
= \frac1N \sum_{i=1}^N \ell_\yi(f(\exi))
\fragment[1]{= \frac1N \sum_{i=1}^N \sum_{c=1}^k \indic(\yi = c) \ell_c(f(\exi)) }
\fragment[2]{= \frac1N \sum_{i=1}^N e_{\yi} \cdot \vec\ell(f(\exi))}
\end{align*}
Potentially better scheme: see $\{(\exi, \ptari)\}$ and minimize
\begin{align*}
L^\tar(f) = \frac1N \sum_{i=1}^N \ptari \cdot \vec\ell(f(\exi))
\end{align*}
This can reduce variance if $\ptari \approx \pstari$, the true conditional probabilities.

Knowledge distillation

Process:
Train a teacher $f^\mathit{teacher}$ on $\{ (\exi, \yi) \}$ with standard ERM, $L(f)$.
Train a student on $\{ (\exi, {\color{cb4} f^\mathit{teacher}}(\exi)) \}$ with $L^\tar$.

Usually ${\color{cb5} f^\mathit{student}}$ is “smaller” than ${\color{cb4} f^\mathit{teacher}}$.
But in “self-distillation” (using the same architecture), ${\color{cb5} f^\mathit{student}}$ often outperforms ${\color{cb4} f^\mathit{teacher}}$!

One possible explanation: ${\color{cb4} f^\mathit{teacher}(\exi)}$ is closer to $\pstari$ than the sampled $\yi$ is.
But why would that be?

Zig-Zagging behaviour in learning

[Figure: plots of (three-way) probabilistic predictions; $\boldsymbol\times$ markers show $\pstari$ and $\yi$.]
eNTK explains it

Let $q_t(\xt) = \operatorname{softmax}(f_t(\xt)) \in \R^k$; for cross-entropy loss, one SGD step gives us
\begin{align*}
q_{t+1}(\xt) - q_t(\xt)
= \eta \; \mathcal A_t(\xt) \, \eNTK_{\w_t}(\xt, \exi) \, (\ptari - q_t(\exi))
{\color{gray}{} + \bigO(\eta^2)}
\end{align*}
where $\mathcal A_t(\xt) = \operatorname{diag}(q_t(\xt)) - q_t(\xt) q_t(\xt)\tp$ is the covariance of a $\operatorname{Categorical}(q_t(\xt))$.
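The factor $\mathcal A_t(\xt)$ is also the Jacobian of the softmax, which is how it enters: the logit change from the eNTK step is pushed through $\operatorname{softmax}$ to first order. A small NumPy check of that identity by central finite differences (the logits and the step \texttt{eps} are arbitrary illustrative choices):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())      # subtract the max for numerical stability
    return e / e.sum()

def A(z):
    """diag(q) - q q^T with q = softmax(z): covariance of Categorical(q)."""
    q = softmax(z)
    return np.diag(q) - np.outer(q, q)

z = np.array([2.0, -1.0, 0.5])   # illustrative logits f_t(x~) for k = 3 classes
eps = 1e-6
# Column c holds the central finite difference of softmax(z) along logit c.
fd = np.stack([(softmax(z + eps * e) - softmax(z - eps * e)) / (2 * eps)
               for e in np.eye(3)], axis=1)
print(np.abs(fd - A(z)).max())   # finite-difference error, should be tiny
```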
Taking a moving average of $q_t(\exi)$ as $\ptari$ improves distillation (especially with noisy labels).
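One simple way to realize such targets, as a sketch only: keep $\ptari$ as an exponential moving average of the model's own predictions $q_t(\exi)$, and train on $L^\tar$ with cross-entropy, $\vec\ell_c(\hat y) = -\log \operatorname{softmax}(\hat y)_c$. The smoothing constant \texttt{alpha}, the one-hot initialization, and the random stand-in logits are hypothetical choices for illustration, not the exact procedure from the paper.

```python
import numpy as np

def log_softmax(F):
    F = F - F.max(axis=1, keepdims=True)     # row-wise shift for numerical stability
    return F - np.log(np.exp(F).sum(axis=1, keepdims=True))

def soft_target_loss(F, P_tar):
    """L^tar = (1/N) sum_i p_i^tar . ell_vec(f(x_i)) with cross-entropy losses."""
    return -(P_tar * log_softmax(F)).sum(axis=1).mean()

def ema_targets(P_tar, Q_now, alpha=0.9):
    """Moving average of predictions q_t(x_i); alpha is a hypothetical choice."""
    return alpha * P_tar + (1 - alpha) * Q_now

rng = np.random.default_rng(4)
N, k = 4, 3
labels = rng.integers(k, size=N)
P = np.eye(k)[labels]                # start from one-hot targets e_{y_i}
for epoch in range(5):
    F = rng.normal(size=(N, k))      # stand-in for the logits f_t(X) at this epoch
    Q = np.exp(log_softmax(F))       # q_t(x_i) = softmax(f_t(x_i))
    print(epoch, soft_target_loss(F, P))
    P = ema_targets(P, Q)            # targets stay valid probability vectors
```

With one-hot targets, \texttt{soft\_target\_loss} reduces to the usual cross-entropy $L(f)$.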
Fine-tuning

Pretrain, re-initialize a random head, then adapt to a downstream task.
Two phases:
Head probing: only update the head $g(z)$.
Fine-tuning: update the head $g(z)$ and the backbone $z = f(x)$ together.

If we only fine-tune: noise from the random head might break our features!
If we head-probe to convergence: we might already fit the training data and not change the features!

How much do we change our features?

Same kind of decomposition with backbone features $z = f(x)$, head logits $q = g(z)$, and predicted probabilities $p = \operatorname{softmax}(q)$:
\begin{align*}
z_{t+1}(\xt) - z_t(\xt)
= \frac{\eta}{N} \sum_{i=1}^N
\underbrace{\eNTK_{\w_t}^{f}(\xt, \exi)}_\text{eNTK of backbone}
\, \underbrace{\left(\nabla_z q_t(\exi)\right)\tp_\phantom{\w_t}\!\!\!}_\text{direction of head}
\, \underbrace{(e_{\yi} - p_t(\exi))_\phantom{\w_t}\!\!\!\!\!\!\!}_\text{“energy”}
{\color{gray}{} + \bigO(\eta^2)}
\end{align*}
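In the simple linear-linear case the decomposition can be checked exactly. This sketch assumes a hypothetical linear backbone $z = B x$ and a linear head with logits $H z$ (probabilities $\operatorname{softmax}(H z)$), and updates only the backbone with one SGD step of cross-entropy on $(\exi, \yi)$. Then the backbone eNTK is $(\xt \cdot \exi) I$, the head direction (the Jacobian of the logits with respect to $z$) is $H$, and the feature change has no $\bigO(\eta^2)$ remainder.

```python
import numpy as np

rng = np.random.default_rng(5)
d, dz, k, eta = 5, 4, 3, 0.1
B = rng.normal(size=(dz, d))      # hypothetical linear backbone: z = B x
H = rng.normal(size=(k, dz))      # hypothetical linear head: logits = H z
x_i, y_i, x_t = rng.normal(size=d), 1, rng.normal(size=d)

def probs(z):
    logits = H @ z
    e = np.exp(logits - logits.max())
    return e / e.sum()

energy = np.eye(k)[y_i] - probs(B @ x_i)         # e_y - p_t(x_i)
grad_B = -np.outer(H.T @ energy, x_i)            # d CE / d B = H^T (p - e_y) x_i^T
B_new = B - eta * grad_B                         # update the backbone only

actual = B_new @ x_t - B @ x_t
# Backbone eNTK is (x_t . x_i) I; the head direction is H, so (grad_z logits)^T = H^T.
predicted = eta * (x_t @ x_i) * (H.T @ energy)
```

If the head had already been probed until $p_t(\exi) \approx e_\yi$, \texttt{energy} and hence the feature change would be close to zero.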
If the initial “energy”, e.g. $\E_{\exi,\yi} \norm{e_{\yi} - p_0(\exi)}$, is small, the features don't change much.

If we didn't do any head probing, the “direction” is very random, especially if $g$ is rich.

Specializing to a simple linear-linear model, we can get insights about trends in $z$.
Recommendations from the paper:
Early stop during head probing (ideally, try multiple lengths for the downstream task).
Label smoothing can help; so can more complex heads, but be careful.
With a random head (no head probing), generalized cross-validation on the eNTK model gives an excellent estimate of the downstream loss.

Pool-based active learning

A pool of unlabeled data is available; ask for annotations of the “most informative” points.

Kinds of acquisition functions used for deep active learning:
Uncertainty-based: maximum entropy, BALD
Representation-based: BADGE, LL4AL

Another kind used for simpler models: lookahead criteria. “How much would my model change if I saw $\exi$ with label $\yi$?”
Too expensive for deep learning… unless you use a local approximation to retraining.

Approximate retraining with local linearization

Given $f_{\mathcal L}$ trained on labeled data $\mathcal L$, approximate $f_{\mathcal L \cup \{(\exi, \yi)\}}$ with a local linearization:
\begin{align*}
\!\!\!\!\!
f_{\mathcal L \cup \{(\exi, \yi)\}}(\xt)
\approx
f_{\mathcal L}(\xt) +
\eNTK_{\w_{\mathcal L}}\left(\xt, \begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right)
\eNTK_{\w_{\mathcal L}}\left(\begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix}, \begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right)^{-1}
\left( \begin{bmatrix} \ys_{\mathcal L} \\ \yi \end{bmatrix} - f_{\mathcal L}\left(\begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right) \right)
\end{align*}
Rank-one updates make this efficient to compute. [Figure: schema]

We prove this is exact for infinitely wide networks: $f_0 \to f_{\mathcal L} \to f_{\mathcal L \cup \{(\exi, \yi)\}}$ agrees with the direct $f_0 \to f_{\mathcal L \cup \{(\exi, \yi)\}}$.

Local approximation with the eNTK “should” work much more broadly than the “NTK regime”:
Much faster than SGD.
Much more effective than the infinite NTK and one-step SGD.
Matches/beats the state of the art.
Downside: usually more computationally expensive (especially in memory).
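A minimal sketch of the lookahead computation, and of why adding one point is cheap, under assumptions: a single output, and a hypothetical stand-in for the eNTK given by the kernel of a model whose only trainable layer is the last one, $\phi(a) \cdot \phi(b)$ with frozen random $\tanh$ features. \texttt{add\_point} turns $K^{-1}$ into the inverse of the enlarged kernel matrix via the Schur complement in $\bigO(n^2)$ rather than re-inverting in $\bigO(n^3)$; this is one way to implement the rank-one updates mentioned above, not necessarily the paper's exact code. The result is compared with a direct solve.

```python
import numpy as np

def add_point(K_inv, b, c):
    """Inverse of [[K, b], [b^T, c]] given K^{-1}, via the Schur complement."""
    u = K_inv @ b
    s = c - b @ u                         # Schur complement (positive if the new matrix is PD)
    top_left = K_inv + np.outer(u, u) / s
    return np.block([[top_left, -u[:, None] / s],
                     [-u[None, :] / s, np.array([[1.0 / s]])]])

rng = np.random.default_rng(6)
d, h, n = 8, 64, 6
W = rng.normal(size=(h, d)) / np.sqrt(d)
v = rng.normal(size=h) / np.sqrt(h)
phi = lambda X: np.tanh(X @ W.T)          # frozen random features
f_L = lambda X: phi(X) @ v                # hypothetical stand-in for the model trained on L
kern = lambda A, B: phi(A) @ phi(B).T     # scalar eNTK if only the last layer v is trainable

X_L, y_L = rng.normal(size=(n, d)), rng.normal(size=n)
x_new, y_new = rng.normal(size=(1, d)), 1.0   # candidate point and hypothetical label
x_t = rng.normal(size=(1, d))                 # point where we measure the change

K_inv = np.linalg.inv(kern(X_L, X_L))
K_inv_new = add_point(K_inv, kern(X_L, x_new)[:, 0], kern(x_new, x_new)[0, 0])

X_aug, y_aug = np.vstack([X_L, x_new]), np.append(y_L, y_new)
resid = y_aug - f_L(X_aug)
lookahead = f_L(x_t)[0] + (kern(x_t, X_aug) @ K_inv_new @ resid)[0]
direct = f_L(x_t)[0] + (kern(x_t, X_aug) @ np.linalg.solve(kern(X_aug, X_aug), resid))[0]
```

Scoring a pool would repeat the last few lines for each candidate $\exi$ (and each possible label), reusing \texttt{K\_inv} each time.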
Enables new interaction modes

“Sequential” querying: incorporate true new labels one at a time instead of in a batch, and only update the $\eNTK$ occasionally.
This makes sense when labels cost \$ but are fast to obtain; other deep AL methods need to retrain.

Approximating empirical NTKs

I hid something from you on the active learning (and Wei/Hu/Steinhardt fine-tuning) results…

With $k$ classes, $\eNTK(\Xs, \Xs) \in \R^{k N \times k N}$: potentially very big.

But actually, we know that $\E_{\w} \eNTK_\w(x_1, x_2)$ is diagonal for most architectures. Let
\begin{align*}
\pNTK_\w(x_1, x_2) =
\underbrace{\left[ \nabla_\w f_1(x_1) \right] }_{1 \times p}
\underbrace{\left[ \nabla_\w f_1(x_2) \right] \tp}_{p \times 1}
% \underbrace{\left[ \nabla_\w\left( \frac{1}{\sqrt k} \sum_{j=1}^k f_j(x_1) \right) \right]}_{1 \times p}
% \underbrace{\left[ \nabla_\w\left( \frac{1}{\sqrt k} \sum_{j=1}^k f_j(x_2) \right) \right]\tp}_{p \times 1}
.
\end{align*}
For such architectures, $\E_\w \eNTK_\w(x_1, x_2) = \E_\w\left[ \pNTK_\w(x_1, x_2) \right] I_k$.
So $\pNTK(\Xs, \Xs) \in \R^{N \times N}$ (no $k$!).

We can also use the “sum of logits” $\frac{1}{\sqrt k} \sum_{j=1}^k f_j$ instead of just the “first logit” $f_1$.
Lots of work (including the above) has used $\pNTK$ instead of $\eNTK$, often without saying anything; sometimes it doesn't seem like they know they're doing it.

Can we justify this more rigorously?

pNTK motivation

Say $f(x) = V \phi(x)$, with $\phi(x) \in \R^h$, and $V \in \R^{k \times h}$ has rows $v_j \in \R^h$ with iid entries.
If $v_{j,i} \sim \mathcal N(0, \sigma^2)$, then $v_1$ and $\frac{1}{\sqrt k} \sum_{j=1}^k v_j$ have the same distribution.
\begin{align*}
\fragment[1]{
\eNTK_\w(x_1, x_2)_{jj'}
}&\fragment[1]{ {}
= v_j\tp \eNTK^\phi_{\w \setminus V}(x_1, x_2) \, v_{j'} + \indic(j = j') \phi(x_1)\tp \phi(x_2)
}
\\
\fragment[2]{\pNTK_\w(x_1, x_2)}
&\fragment[2]{ {}
= {\color{cb4} v_1\tp} \eNTK^\phi_{\w \setminus V}(x_1, x_2) \, {\color{cb4} v_1} + \phi(x_1)\tp \phi(x_2)
}
\end{align*}
We want to bound the difference $\eNTK(x_1, x_2) - \pNTK(x_1, x_2) I_k$.

That is, we want $v_1\tp A v_1$ and $v_j\tp A v_j$ to be close, and $v_j\tp A v_{j'}$ to be small, for random $v$ and fixed $A$.

Using Hanson-Wright, with probability at least $1 - \delta$:
\begin{align*}
\frac{\norm{\eNTK - \pNTK I}_F}{\norm{\eNTK}_F} \le \frac{\norm{\eNTK^\phi}_F + 4 \sqrt h}{\Tr(\eNTK^\phi)} k \log \frac{2 k^2}{\delta}
\end{align*}
For fully-connected ReLU nets at initialization (fan-in mode), the numerator is $\bigO(h \sqrt h)$ and the denominator is $\Theta(h^2)$, so for fixed $k$ and $\delta$ the bound is $\bigO(1/\sqrt h)$.

pNTK's Frobenius error

Same kind of theorem / empirical results for the largest eigenvalue, and empirical results for $\lambda_\min$ and the condition number.
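An empirical sketch of the Frobenius-error claim, assuming a two-layer ReLU network $f(x) = \frac{1}{\sqrt h} V \operatorname{relu}(W x)$ with $k = 3$ outputs at a Gaussian NTK-parametrized initialization (a hypothetical setup, not the paper's experiments). It builds the full $kN \times kN$ eNTK and the first-logit pNTK in closed form, then reports $\norm{\eNTK - \pNTK \otimes I_k}_F / \norm{\eNTK}_F$ at two widths; the error should drop substantially as $h$ grows.

```python
import numpy as np

def entk_and_pntk(X, h, rng, k=3):
    """Full eNTK (kN x kN, sample-major) and first-logit pNTK (N x N) of
    f(x) = V relu(W x) / sqrt(h) at a random NTK-parametrized initialization."""
    N, d = X.shape
    W, V = rng.normal(size=(h, d)), rng.normal(size=(k, h))
    pre = X @ W.T
    Phi, Act = np.maximum(pre, 0), (pre > 0).astype(float)    # each (N, h)
    PP, G = Phi @ Phi.T, X @ X.T
    # eNTK[a, j, b, j'] = (1/h) [PP[a, b] 1(j = j') + G[a, b] sum_m V[j,m] V[j',m] Act[a,m] Act[b,m]]
    S = np.einsum("jm,km,am,bm->ajbk", V, V, Act, Act, optimize=True)
    entk = (np.einsum("ab,jk->ajbk", PP, np.eye(k)) + G[:, None, :, None] * S) / h
    # pNTK uses only the first logit f_1, i.e. the first row of V.
    pntk = (PP + G * np.einsum("m,am,bm->ab", V[0] ** 2, Act, Act, optimize=True)) / h
    return entk.reshape(N * k, N * k), pntk

rng = np.random.default_rng(7)
X = rng.normal(size=(5, 4))
rel_err = {}
for h in [16, 4096]:
    entk, pntk = entk_and_pntk(X, h, rng)
    k = entk.shape[0] // X.shape[0]
    rel_err[h] = np.linalg.norm(entk - np.kron(pntk, np.eye(k))) / np.linalg.norm(entk)
    print(h, rel_err[h])
```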
Kernel regression with pNTK

Reshape things to handle prediction appropriately:
\begin{align*}
\underbrace{f_{\eNTK}(\xt)}_{k \times 1} &=
\underbrace{f_0(\xt)}_{k \times 1}
+ \phantom{\Big(}
\underbrace{\eNTK_{\w_0}(\xt, \Xs)}_{k \times k N}
\,\underbrace{\eNTK_{\w_0}(\Xs, \Xs)^{-1}}_{k N \times k N}
\,\underbrace{(\ys - f_0(\Xs))}_{kN \times 1}
\\
\underbrace{f_{\pNTK}(\xt)}_{k \times 1} &=
\underbrace{f_0(\xt)}_{k \times 1}
+ \Big(
\underbrace{\pNTK_{\w_0}(\xt, \Xs)}_{1 \times N}
\,\underbrace{\pNTK_{\w_0}(\Xs, \Xs)^{-1}}_{N \times N}
\,\underbrace{(\ys - f_0(\Xs))}_{N \times k}
\Big)\tp
\end{align*}
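A sketch of the reshaping, assuming a hypothetical random positive-definite matrix in place of a real pNTK and sample-major stacking of the $kN$ outputs. When $\eNTK = \pNTK \otimes I_k$ exactly, the two formulas above coincide, but the pNTK version only solves an $N \times N$ system instead of a $kN \times kN$ one.

```python
import numpy as np

rng = np.random.default_rng(8)
N, k = 6, 3
A = rng.normal(size=(N + 1, N + 1))
K_all = A @ A.T + np.eye(N + 1)          # hypothetical PD pNTK on N train points + 1 test point
K_XX, K_tX = K_all[:N, :N], K_all[N, :N]
R = rng.normal(size=(N, k))              # y - f_0(X), one row per training point
f0_t = rng.normal(size=k)                # f_0(x~)

# pNTK form: (1 x N)(N x N)(N x k), then transposed into k outputs.
f_pntk = f0_t + K_tX @ np.linalg.solve(K_XX, R)

# eNTK form with eNTK = pNTK (kron) I_k: (k x kN)(kN x kN)(kN x 1), sample-major.
E_XX, E_tX = np.kron(K_XX, np.eye(k)), np.kron(K_tX[None, :], np.eye(k))
f_entk = f0_t + E_tX @ np.linalg.solve(E_XX, R.reshape(-1))
```

The $N \times N$ solve costs roughly $\bigO(N^3)$, versus $\bigO(k^3 N^3)$ for the full eNTK system.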
We again have $\norm{f_{\eNTK}(\xt) - f_{\pNTK}(\xt)} = \bigO(\frac{1}{\sqrt h})$.

If we add regularization, we need to “scale” $\lambda$ between the two.

[Figures: kernel regression with pNTK; pNTK speed-up; pNTK speed-up on active learning task]

pNTK for full CIFAR-10 regression

$\eNTK(\Xs, \Xs)$ on CIFAR-10: 1.8 TiB (about 2 TB) of memory.
$\pNTK(\Xs, \Xs)$ on CIFAR-10: about 18.6 GiB (20 GB) of memory.
Worse than the infinite NTK for FCN/ConvNet (where it can be computed, if you try hard).
Way worse than SGD.

Recap

eNTK is a good tool for intuitive understanding of the learning process.
eNTK is practically very effective at “lookahead” for active learning.
You should probably use pNTK instead of eNTK for high-dim output problems: