\definecolor{cb1}{RGB}{76,114,176}
\definecolor{cb2}{RGB}{221,132,82}
\definecolor{cb3}{RGB}{85,168,104}
\definecolor{cb4}{RGB}{196,78,82}
\definecolor{cb5}{RGB}{129,114,179}
\definecolor{cb6}{RGB}{147,120,96}
\definecolor{cb7}{RGB}{218,139,195}
\newcommand{\abs}[1]{\left\lvert #1 \right\rvert}
\newcommand{\norm}[1]{\left\lVert #1 \right\rVert}
\DeclareMathOperator*{\argmin}{argmin}
\DeclareMathOperator{\bigO}{\mathcal{O}}
\DeclareMathOperator{\Cov}{Cov}
\DeclareMathOperator{\E}{\mathbb{E}}
\newcommand{\indic}{\mathbb{I}}
\newcommand{\R}{\mathbb{R}}
\DeclareMathOperator{\Softmax}{Softmax}
\newcommand{\tp}{^\mathsf{T}}
\DeclareMathOperator{\Tr}{Tr}
\DeclareMathOperator{\Var}{Var}
\DeclareMathOperator{\NTK}{NTK}
\DeclareMathOperator{\eNTK}{\mathcal K}
\DeclareMathOperator{\pNTK}{pNTK}
\newcommand{\lin}{\mathit{lin}}
\newcommand{\xt}{{\color{cb2} \tilde x}}
\newcommand{\yt}{{\color{cb2} \tilde y}}
\newcommand{\exi}{{\color{cb1} x_i}}
\newcommand{\yi}{{\color{cb1} y_i}}
\newcommand{\yip}{{\color{cb1} y_i^+}}
\newcommand{\yim}{{\color{cb1} y_i^-}}
\newcommand{\Xs}{{\color{cb1} \mathbf X}}
\newcommand{\ys}{{\color{cb1} \mathbf y}}
\newcommand{\f}{\mathbf{f}}
\newcommand{\w}{\mathbf{w}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\z}{\mathbf{z}}
\newcommand{\st}{\boldsymbol{\chi}}
\newcommand{\sti}{{\color{cb1} \chi_i}}
\newcommand{\stip}{{\color{cb1} \chi_i^+}}
\newcommand{\stim}{{\color{cb1} \chi_i^-}}
\newcommand{\stt}{{\color{cb2} \tilde\chi}}
\newcommand{\tar}{\mathit{tar}}
\newcommand{\ptari}{{\color{cb1} p^\tar_i}}
\newcommand{\pstari}{{\color{cb1} p^*_i}}
Local Learning Dynamics Help Explain (Post-)Training Behaviour

Danica J. Sutherland (she)
University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
Knowledge distillation analysis [ICLR-22]: Yi Ren, Shangmin Guo
Fine-tuning analysis [ICLR-23]: Yi Ren, Shangmin Guo, Wonho Bae
Active learning [NeurIPS-22]: M. Amin Mohamadi, Wonho Bae
Additional collaborators: M. Amin Mohamadi, Wonho Bae, Wenlong Deng, Yi Ren, Muchen Li, Xiaoxiao Li, Christos Thrampoulidis
Mila – June 2025
Slides available at djsutherland.ml/slides/entk-mila
LLM “post-training”

Language models (e.g. GPT-2): scrape up a ton of the internet (usually illegally), then train a big Transformer for next-token prediction. Super-useful as a component of lots of things…but not necessarily what we want by itself.

Turning a language model into a chatbot (e.g. ChatGPT):
  Run “supervised fine-tuning” on a dataset of chatbot-like interactions
  Run “preference optimization”: given prompt $x$, say $A$, not $B$

Surprises in LLM post-training

Preference optimization: “given prompt $x$, say $A$, not $B$.” A common algorithm is Direct Preference Optimization [RSM+ NeurIPS-23].
Weird things can happen here! It makes $B$ way less likely, but eventually the model almost always says some $C$. There are some workarounds, but…why?

Learning dynamics

Most theoretical analyses in this area ask: what do optimal solutions look like? It turns out the loss function is very underspecified; there are many optimal solutions. “Implicit regularization” studies which one GD/SGD/Adam/… eventually converges to, but “eventually” can take a really long time.

We'll take a related, but more qualitative, approach: what does each step do to the model? The mapping from parameters to “what the model does” is complicated. When I take an SGD step to “learn” $f(\exi) \approx \yi$, what happens to my predictions on $\xt$?

Learning dynamics (Taylor's version)

“Learning dynamics” of a model: $f_t(\xt)$ for some fixed $\xt$ as the params $\w_t$ change.

Suppose $\z = z(x)$, $\f = \sigma(\z)$, and we run “plain” SGD on $\frac1N \sum_{i=1}^N \mathcal L_t(\exi, \yi)$:
\begin{align*}
\!\!\!\!\!\!\!\!\!\!\!
\underbrace{\f_{t+1}(\xt)}_{k \times 1} - \underbrace{\f_t(\xt)}_{k \times 1}
&= \underbrace{\left( \nabla_\w \f(\xt) \rvert_{\w_t} \right)}_{k \times p}
\,
\underbrace{(\w_{t+1} - \w_t)}_{p \times 1}
{\color{gray}{} + \bigO\left( \norm{\w_{t+1} - \w_t}^2 \right)}
\\&\fragment[3]{ {}
= \underbrace{\left( \nabla_\w \f(\xt) \rvert_{\w_t} \right)}_{k \times p}
\Bigl( - \eta \underbrace{\nabla_\w \mathcal L(\exi, \yi) \rvert_{\w_t}}_{1 \times p} \Bigr)\tp
{\color{gray}{} + \bigO(\eta^2)}
}
\\&\fragment[4]{ {}
= - \eta \,
\underbrace{\bigl( \nabla_\z \f(\xt) \rvert_{\z_t} \bigr)}_{k \times k}
\underbrace{\bigl( \nabla_\w \z(\xt) \rvert_{\w_t} \bigr)}_{k \times p}
\;
\underbrace{ \left( \nabla_\w \z(\exi) \rvert_{\w_t} \right)\tp }_{p \times k} \,
\underbrace{ \left( \nabla_\z \mathcal L(\exi, \yi) \rvert_{\z_t} \right)\tp }_{k \times 1} \,
{\color{gray}{} + \bigO(\eta^2)}
}
\\&\fragment[5]{ {}
= - \eta \,
\underbrace{\mathcal A_t(\xt)}_{k \times k} \;
\underbrace{\mathcal K_t(\xt, \exi)}_{k \times k} \;
\underbrace{\mathcal G_t(\exi, \yi)}_{k \times 1}
{\color{gray}{} + \bigO(\eta^2)}
}
\end{align*}
Learning dynamics

“Learning dynamics” of a model: $f_t(\xt)$ for some fixed $\xt$ as the params $\w_t$ change.

To start: $\z = h_\theta(x)$, $f = \sigma(\z)$, “plain” SGD on $\frac1N \sum_{i=1}^N \mathcal L_t(\exi, \yi)$:
\f_{t+1}(\xt) - \f_t(\xt)
= - \eta \; \mathcal A_t(\xt) \; \mathcal K_t(\xt, \exi) \; \mathcal G_t(\exi, \yi)
{\color{gray}{} + \bigO(\eta^2)}

$\mathcal G_t(\exi, \yi) = \nabla_\z \mathcal L(\exi, \yi) \rvert_{\z_t}$: how much do I need to change my $\exi$ prediction?
  For square loss with $\sigma(\z) = \z$, $\mathcal G_t = \f_t(\exi) - \yi$: how wrong was I before?
  For cross-entropy on logits, $\mathcal G_t = \Softmax(\z_t(\exi)) - e_{\yi}$: how wrong was I before?

$\mathcal A_t(\xt) = \nabla_\z \sigma(\z_t)$ just “converts” prediction changes.
  If $\sigma(\z) = \z$, $\mathcal A_t$ is the identity; if $\sigma = \log \Softmax$, $\mathcal A_t = \mathbf I_k - \mathbf{1}_k \, \pi_t(\xt)\tp$.

$\mathcal K_t(\xt, \exi) = (\nabla_\w \z(\xt)\rvert_{\w_t}) (\nabla_\w \z(\exi)\rvert_{\w_t})\tp$ is the $k \times k$ empirical neural tangent kernel of $\z$.
  If $\exi$, $\xt$ are “dissimilar” (small eNTK), stepping on $(\exi, \yi)$ barely changes the $\xt$ prediction.
  If $\exi$, $\xt$ are “similar” (large eNTK), it makes the $\xt$ prediction more like $\yi$.
Example: learning dynamics on MNIST
\log \pi_{t+1}(\xt) - \log \pi_t(\xt) \approx - \eta \mathcal A_t(\xt) \, \mathcal K_t(\xt, \exi) \, \mathcal G_t(\exi, \yi)
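To make the decomposition concrete, here is a minimal sketch (mine, not from the slides) that checks $\f_{t+1}(\xt) - \f_t(\xt) \approx -\eta \, \mathcal A_t(\xt) \, \mathcal K_t(\xt, \exi) \, \mathcal G_t(\exi, \yi)$ numerically for a tiny two-layer network with log-softmax outputs and cross-entropy loss; the architecture, sizes, and step size are all arbitrary illustrative choices.

```python
import jax, jax.numpy as jnp

def logits(w, x):                                   # z = h_w(x): a tiny 2-layer MLP
    return jnp.tanh(x @ w["W1"] + w["b1"]) @ w["W2"] + w["b2"]

def log_probs(w, x):                                # f = log softmax(z)
    return jax.nn.log_softmax(logits(w, x))

def loss(w, x, y):                                  # cross-entropy on one example
    return -log_probs(w, x)[y]

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
d, width, k, eta = 5, 16, 3, 1e-2                   # arbitrary sizes / step size
w = {"W1": 0.3 * jax.random.normal(k1, (d, width)), "b1": jnp.zeros(width),
     "W2": 0.3 * jax.random.normal(k2, (width, k)), "b2": jnp.zeros(k)}
x_i, x_tilde = jax.random.normal(k3, (d,)), jax.random.normal(k4, (d,))
y_i = 1

# Exact change in predictions on x_tilde from one SGD step on (x_i, y_i)
g = jax.grad(loss)(w, x_i, y_i)
w_new = jax.tree_util.tree_map(lambda p, gp: p - eta * gp, w, g)
exact = log_probs(w_new, x_tilde) - log_probs(w, x_tilde)

# First-order decomposition: -eta * A_t(x~) @ K_t(x~, x_i) @ G_t(x_i, y_i)
def jac_z(x):                                       # (k, p) Jacobian of the logits w.r.t. params
    J = jax.jacrev(lambda w_: logits(w_, x))(w)
    return jnp.concatenate([a.reshape(k, -1) for a in jax.tree_util.tree_leaves(J)], axis=1)

K = jac_z(x_tilde) @ jac_z(x_i).T                   # empirical NTK of z, k x k
pi = jax.nn.softmax(logits(w, x_tilde))
A = jnp.eye(k) - jnp.ones((k, 1)) * pi[None, :]     # A_t = I_k - 1_k pi_t(x~)^T for log-softmax
G = jax.nn.softmax(logits(w, x_i)) - jax.nn.one_hot(y_i, k)   # G_t for cross-entropy
approx = -eta * A @ K @ G

print(exact)    # the two printed vectors should agree up to O(eta^2)
print(approx)
```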
But wait…aren't NTKs an unrealistic approximation?

A quick aside: the “NTK regime” and infinite limits
Full-batch GD: “stacking things up”,
\begin{align*}
f_{t+1}(\xt) - f_t(\xt)
&= - \frac{\eta}{N} \sum_{i=1}^N \mathcal A_t(\xt) \eNTK_{t}(\xt, \exi) \mathcal G_t(\exi, \yi)
{\color{gray}{} + \bigO(\eta^2)}
\\&\fragment[1]{ {}
= - \frac{\eta}{N}
\underbrace{\mathcal A_t(\xt)}_{k \times k} \;
\underbrace{\eNTK_t(\xt, \Xs)}_{k \times k N} \;
\underbrace{\mathcal G_t(\Xs, \ys)}_{k N \times 1}
{\color{gray}{} + \bigO(\eta^2)}
}
\end{align*}
Observation I: If $f$ is “wide enough” with any usual architecture+init* [Yang+Litwin 2021], $\eNTK_t(\cdot, \Xs)$ is roughly constant through training.
For square loss, $\mathcal G_t(\Xs, \ys) = f_t(\Xs) - \ys$: the dynamics agree with kernel regression!
f_t(\xt) \xrightarrow{t \to \infty} \eNTK_{0}(\xt, \Xs) \eNTK_{0}(\Xs, \Xs)^{-1} (\ys - f_0(\Xs)) + f_0(\xt)

Observation II: As $f$ becomes “infinitely wide” with any usual architecture+init* [Yang 2019], $\eNTK_{0}(x_1, x_2) \xrightarrow{a.s.} \NTK(x_1, x_2)$, independent of the random $\w_0$.
Infinite NTKs are great

Infinitely-wide neural networks have very simple behaviour! No need to worry about bad local minima, optimization complications, …
Understanding the “implicit bias” of wide nets $\approx$ understanding the NTK norm of functions.
Can compute $\NTK$ exactly for many architectures: a great kernel for many kernel methods!

But (infinite) NTKs aren't “the answer”

Computational expense:
  Poor scaling for large-data problems: typically $n^2$ memory and $n^2$ to $n^3$ computation.
    CIFAR-10 has $n = 50\,000$, $k = 10$: an $nk \times nk$ matrix of float64s is 2 terabytes!
    ILSVRC2012 has $n \approx 1\,200\,000$, $k = 1\,000$: 11.5 million terabytes (exabytes).
  For deep/complex models (especially CNNs), each pair is very slow / memory-intensive.
  Attention is even harder to handle.
Practical performance:
  Typically performs worse than GD for “non-small-data” tasks (MNIST and up).
Theoretical limitations:
  NTK “doesn't do feature learning”: we now know many problems where gradient descent on an NN $\gg$ any kernel method.
  There are cases where the GD error $\to 0$, but any kernel is barely better than random [Malach+ 2021].

What can we learn from empirical NTKs?

As a theoretical tool for local understanding:
  Why DPO breaks
  Why GRPO does weird stuff + how to fix
  Fine-grained explanation for early stopping in knowledge distillation
  How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees
Adapting to the LLM setting

First problem: we don't classify a full response at a time; we do it token-by-token.
Once we've framed it correctly, this is fine: stack prompt+response into $\sti = [ \exi, \yi ]$.
The change in $\log \pi(\yt_{\color{cb1}m} \mid \xt, \yt_{:{\color{cb1}m}-1})$ based on a token-by-token update of $\exi, \yi$ is
[\Delta \log \pi_{t}(\yt \mid \stt)]_m
= - \sum_{{\color{cb1}l} = 1}^{\color{cb1} L_i} \eta [\mathcal A_t(\stt)]_{\color{cb2} m} [ \eNTK_t(\stt, \sti) ]_{{\color{cb2}m},\color{cb1}{l}} [ \mathcal G_t(\sti) ]_{\color{cb1}l}
{\color{gray}{} + \bigO(\eta^2)}
Second problem: we can't check all possible output probabilities anymore.
Workaround: track some informative possible responses:
  The dataset responses, rephrases, similar strings with different meanings
  Irrelevant responses in the training set, random sentences…

LLM supervised fine-tuning

SFT makes dispreferred answers more likely…because they're “similar enough” to the preferred ones:
  $\eNTK$ is reasonably large; $\mathcal G$ starts big (pulls up), gets small (pulls up less)
  Ungrammatical responses just go down; $\eNTK$ is small, so there is no upwards pressure
Also makes answers to different questions more likely…one form of hallucination?

Direct Preference Optimization (DPO)
\mathcal L_t^{\mathrm{DPO}}(\exi, \yip, \yim)
= -\log \sigma\left(\beta\left[
\log \frac{\pi_t(\yip \mid \exi)}{\pi_\mathrm{ref}(\yip \mid \exi)}
- \log \frac{\pi_t(\yim \mid \exi)}{\pi_\mathrm{ref}(\yim \mid \exi)}
\right]\right)
which gives that $[\Delta \log \pi_{t}(\yt \mid \stt)]_m$ is approximately the following, writing $\mathcal G_t^\mathrm{DPO}(\st) = \beta \bigl(1 - \sigma\bigl( \dots \bigr)\bigr) \bigl( \pi_t(\y \mid \st) - e_{\y} \bigr)$:
-\eta [\mathcal A_t(\stt)]_{\color{cb2} m} \Biggl(
\sum_{{\color{cb1}l} = 1}^{\color{cb1} L_i}
[ \eNTK_t(\stt, \stip) ]_{{\color{cb2}m},\color{cb1}{l}}
[ \mathcal G_t^\mathrm{DPO}(\stip) ]_{\color{cb1}l}
-
\sum_{{\color{cb1}l} = 1}^{\color{cb1} L_i}
[ \eNTK_t(\stt, \stim) ]_{{\color{cb2}m},\color{cb1}{l}}
[ \mathcal G_t^\mathrm{DPO}(\stim) ]_{\color{cb1}l}
\Biggr)
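As a sanity check on the per-response term, here is a toy sketch (not from the talk) in which the “policy” is just a pair of free logit vectors for single-token responses; autograd on the DPO loss recovers $\mathcal G_t^\mathrm{DPO}$ for the preferred response and its negation for the dispreferred one, which is where the minus sign in front of the second sum above comes from. All names and sizes are illustrative.

```python
import jax, jax.numpy as jnp

V, beta = 6, 0.1                                          # vocab size and DPO temperature (arbitrary)
key = jax.random.PRNGKey(0)
kp, km, kr = jax.random.split(key, 3)
params = {"z_plus":  0.5 * jax.random.normal(kp, (V,)),   # logits in the y+ context
          "z_minus": 0.5 * jax.random.normal(km, (V,))}   # logits in the y- context
z_ref = 0.5 * jax.random.normal(kr, (V,))                 # frozen reference-model logits
y_plus, y_minus = 2, 4                                    # preferred / dispreferred tokens

def margin(p):
    logp = lambda z, y: jax.nn.log_softmax(z)[y]
    return (logp(p["z_plus"], y_plus) - logp(z_ref, y_plus)) \
         - (logp(p["z_minus"], y_minus) - logp(z_ref, y_minus))

def dpo_loss(p):                                          # -log sigma(beta * margin)
    return -jnp.log(jax.nn.sigmoid(beta * margin(p)))

grads = jax.grad(dpo_loss)(params)

# Closed form from the slide: G^DPO(chi) = beta (1 - sigma(...)) (pi_t(y | chi) - e_y)
coef = beta * (1.0 - jax.nn.sigmoid(beta * margin(params)))
G_plus  = coef * (jax.nn.softmax(params["z_plus"])  - jax.nn.one_hot(y_plus, V))
G_minus = coef * (jax.nn.softmax(params["z_minus"]) - jax.nn.one_hot(y_minus, V))

print(jnp.allclose(grads["z_plus"],  G_plus,  atol=1e-6),   # gradient on the y+ logits:  G^DPO(chi+)
      jnp.allclose(grads["z_minus"], -G_minus, atol=1e-6))  # gradient on the y- logits: -G^DPO(chi-)
```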
This negative gradient can do really weird things:
  Positive gradients cancel out…in the positive context
  The squeezing effect accumulates over time

What can we learn from empirical NTKs?

As a theoretical tool for local understanding:
  Why DPO breaks
  Why GRPO does weird stuff + how to fix
  Fine-grained explanation for early stopping in knowledge distillation
  How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees

Group Relative Policy Optimization (GRPO) [DeepSeekMath 24]

Similar to a “group-wise” version of DPO; negative gradients have a similar effect!
Negative token hidden rewards: estimate which tokens are bad by their correlation with tokens in positive responses, and down-weight penalties on tokens that are probably okay.

What can we learn from empirical NTKs?

As a theoretical tool for local understanding:
  Why DPO breaks
  Why GRPO does weird stuff + how to fix
  Fine-grained explanation for early stopping in knowledge distillation
  How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees

Better supervisory signal implies better learning

Classification: the target is $\mathcal L_P = \E_{(x, y)} \mathcal L(x, y) = \E_x \E_{y \mid x} \ell_y(f(x))$.
Normally: see $\{(\exi, \yi)\}$ and minimize ($\vec\ell(\hat y) \in \R^k$ is the vector of losses for all possible labels):
\!\!\!
\mathcal L_{\Xs, \ys}
= \frac1N \sum_{i=1}^N \ell_{\yi}(f(\exi))
\fragment[1]{= \frac1N \sum_{i=1}^N \sum_{c=1}^k \indic(\yi = c) \ell_c(f(\exi)) }
\fragment[2]{= \frac1N \sum_{i=1}^N e_{\yi} \cdot \vec\ell(f(\exi))}
A potentially better scheme: see $\{(\exi, \ptari)\}$ and minimize
\displaystyle L^\tar(f) = \frac1N \sum_{i=1}^N \ptari \cdot \vec\ell(f(\exi))
This can reduce variance if $\ptari \approx \pstari$, the true conditional probabilities.

Knowledge distillation

Process:
  Train a teacher $f^\mathit{teacher}$ on $\{ (\exi, \yi) \}$ with standard ERM, $L(f)$
  Train a student on $\{ (\exi, {\color{cb4} f^\mathit{teacher}}(\exi)) \}$ with $L^\tar$
Usually ${\color{cb5} f^\mathit{student}}$ is “smaller” than ${\color{cb4} f^\mathit{teacher}}$.
But with “self-distillation” (using the same architecture), ${\color{cb5} f^\mathit{student}}$ often outperforms ${\color{cb4} f^\mathit{teacher}}$!
One possible explanation: ${\color{cb4} f^\mathit{teacher}(\exi)}$ is closer to $\pstari$ than the sampled $\yi$.
But why would that be?

Zig-Zagging behaviour in learning

Plots of (three-way) probabilistic predictions: one $\boldsymbol\times$ marker shows $\pstari$, the other shows $\yi$.
eNTK explains it

Let $q_t(\xt) = \operatorname{softmax}(f_t(\xt)) \in \R^k$; for cross-entropy loss, one SGD step gives us
q_{t+1}(\xt) - q_t(\xt)
= \eta \; \mathcal A_t(\xt) \, \eNTK_{\w_t}(\xt, \exi) \, (\ptari - q_t(\exi))
{\color{gray}{} + \bigO(\eta^2)}
$\mathcal A_t(\xt) = \operatorname{diag}(q_t(\xt)) - q_t(\xt) q_t(\xt)\tp$ is the covariance of a $\operatorname{Categorical}(q_t(\xt))$ distribution.
It improves distillation (especially with noisy labels) to take a moving average of $q_t(\exi)$ as $\ptari$.
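A minimal sketch (not the paper's actual procedure) of the two pieces above: the soft-target objective $L^\tar$, which reduces to ordinary cross-entropy for one-hot targets, and using a moving average of $q_t(\exi)$ as $\ptari$ for a distilled student. The linear model, EMA rate, and step counts are illustrative assumptions.

```python
import jax, jax.numpy as jnp

def predict(w, X):                             # a linear "network": logits = X @ w
    return X @ w

def soft_target_loss(w, X, P_tar):             # L^tar: (1/N) sum_i p_i^tar . (-log q(x_i))
    logq = jax.nn.log_softmax(predict(w, X), axis=-1)
    return -jnp.mean(jnp.sum(P_tar * logq, axis=-1))

key = jax.random.PRNGKey(0)
kx, ky, kw, ks = jax.random.split(key, 4)
N, d, k, eta, ema = 64, 10, 3, 0.5, 0.9        # sizes and rates are arbitrary
X = jax.random.normal(kx, (N, d))
y = jax.random.randint(ky, (N,), 0, k)         # (possibly noisy) hard labels

# With one-hot targets, L^tar is just ordinary cross-entropy: e_y . loss_vec(f(x))
hard_targets = jax.nn.one_hot(y, k)
grad_fn = jax.jit(jax.grad(soft_target_loss))

# Phase 1: train a "teacher" on hard labels, keeping a moving average of q_t(x_i)
w_teacher = 0.01 * jax.random.normal(kw, (d, k))
p_tar = jnp.full((N, k), 1.0 / k)
for t in range(300):
    w_teacher = w_teacher - eta * grad_fn(w_teacher, X, hard_targets)
    q_t = jax.nn.softmax(predict(w_teacher, X), axis=-1)
    p_tar = ema * p_tar + (1 - ema) * q_t      # averaged predictions used as p_i^tar

# Phase 2: train a fresh "student" against the averaged soft targets (distillation)
w_student = 0.01 * jax.random.normal(ks, (d, k))
for t in range(300):
    w_student = w_student - eta * grad_fn(w_student, X, p_tar)
```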
What can we learn from empirical NTKs?

As a theoretical tool for local understanding:
  Why DPO breaks
  Why GRPO does weird stuff + how to fix
  Fine-grained explanation for early stopping in knowledge distillation
  How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees

Fine-tuning

Pretrain, re-initialize a random head, then adapt to a downstream task. Two phases:
  Head probing: only update the head $g(z)$
  Fine-tuning: update the head $g(z)$ and the backbone $z = f(x)$ together
If we only fine-tune: noise from the random head might break our features!
If we head-probe to convergence: we might already fit the training data and not change the features!

How much do we change our features?

Same kind of decomposition, with backbone features $z = f(x)$ and head $q = \operatorname{softmax}(g(z))$:
z_{t+1}(\xt) - z_t(\xt)
= \frac{\eta}{N} \sum_{i=1}^N
\underbrace{\eNTK_{\w_t}^{f}(\xt, \exi)}_\text{eNTK of backbone}
\, \underbrace{\left(\nabla_z q_t(\exi)\right)\tp}_\text{direction of head}
\, \underbrace{(e_{\yi} - q_t(\exi))}_\text{“energy”}
{\color{gray}{} + \bigO(\eta^2)}
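A small illustration (my toy setup, not the paper's) of the “energy” term above: with a linear backbone and a softmax head, a head probed to convergence leaves $\norm{e_{\yi} - q(\exi)}$ small, so one step pushes the backbone much less than a freshly re-initialized head does. All shapes and learning rates are illustrative.

```python
import jax, jax.numpy as jnp

def loss(params, X, y):
    z = X @ params["B"]                              # backbone features z = f(x)
    logits = z @ params["W"]                         # head g(z)
    return -jnp.mean(jax.nn.log_softmax(logits)[jnp.arange(X.shape[0]), y])

key = jax.random.PRNGKey(0)
kx, kb, kw, kt = jax.random.split(key, 4)
N, d, m, k = 128, 20, 8, 3                           # arbitrary sizes
X = jax.random.normal(kx, (N, d))
B = 0.1 * jax.random.normal(kb, (d, m))              # "pretrained" backbone, held fixed here
W_true = jax.random.normal(kt, (m, k))
y = jnp.argmax(X @ B @ W_true, axis=-1)              # labels the head can actually fit

def backbone_grad_norm(W):                           # how hard one step would push the backbone
    g = jax.grad(loss)({"B": B, "W": W}, X, y)
    return jnp.linalg.norm(g["B"])

def energy(W):                                       # || e_y - q(x) || over the training set
    q = jax.nn.softmax(X @ B @ W, axis=-1)
    return jnp.linalg.norm(jax.nn.one_hot(y, k) - q)

W_random = 0.5 * jax.random.normal(kw, (m, k))       # freshly re-initialized head

# "Head probing": update only W for a while, keeping the backbone B fixed
head_grad = jax.jit(lambda W: jax.grad(loss)({"B": B, "W": W}, X, y)["W"])
W_probed = W_random
for t in range(500):
    W_probed = W_probed - 1.0 * head_grad(W_probed)

print(energy(W_random), backbone_grad_norm(W_random))   # larger energy, larger feature change
print(energy(W_probed), backbone_grad_norm(W_probed))   # smaller energy, smaller feature change
```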
If the initial “energy”, e.g. $\E_{\exi,\yi} \norm{e_{\yi} - q_0(\exi)}$, is small, the features don't change much.
If we didn't do any head probing, the “direction” is very random, especially if $g$ is rich.
Specializing to a simple linear-linear model, we can get insights about trends in $z$.

Recommendations from the paper:
  Early stop during head probing (ideally, try multiple lengths for the downstream task)
  Label smoothing can help; so can more complex heads, but be careful
  With a random head (no head probing), generalized cross-validation on the eNTK model gives an excellent estimate of downstream loss

What can we learn from empirical NTKs?

As a theoretical tool for local understanding:
  Why DPO breaks
  Why GRPO does weird stuff + how to fix
  Fine-grained explanation for early stopping in knowledge distillation
  How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees

Pool-based active learning

A pool of unlabeled data is available; ask for annotations of the “most informative” points.
Kinds of acquisition functions used for deep active learning:
  Uncertainty-based: maximum entropy, BALD
  Representation-based: BADGE, LL4AL
Another kind, used for simpler models, is lookahead criteria: “How much would my model change if I saw $\exi$ with label $\yi$?”
That's too expensive for deep learning…unless you use a local approximation to retraining.

Approximate retraining with local linearization

Given $f_{\mathcal L}$ trained on labeled data $\mathcal L$, approximate $f_{\mathcal L \cup \{(\exi, \yi)\}}$ with the local linearization
f_{\mathcal L \cup \{(\exi, \yi)\}}(\xt)
\approx
f_{\mathcal L}(\xt) +
\eNTK_{\w_{\mathcal L}}\left(\xt, \begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right)
\eNTK_{\w_{\mathcal L}}\left(\begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix}, \begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right)^{-1}
\left( \begin{bmatrix} \ys_{\mathcal L} \\ \yi \end{bmatrix} - f_{\mathcal L}\left(\begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right) \right)
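Here is a compact sketch of that lookahead computation for a scalar-output toy regression model, so every kernel block is a scalar; for brevity it linearizes around a randomly initialized $w$ standing in for $w_{\mathcal L}$, and the ridge jitter is an added numerical convenience, not part of the formula. The data and labels are made up.

```python
import jax, jax.numpy as jnp

def f(w, x):                                           # tiny scalar-output MLP
    return (jnp.tanh(x @ w["W1"] + w["b1"]) @ w["W2"] + w["b2"])[0]

key = jax.random.PRNGKey(0)
k1, k2, kx, kc, ks = jax.random.split(key, 5)
d, m, nL = 4, 32, 20                                   # arbitrary sizes
w = {"W1": 0.5 * jax.random.normal(k1, (d, m)), "b1": jnp.zeros(m),
     "W2": 0.5 * jax.random.normal(k2, (m, 1)), "b2": jnp.zeros(1)}
X_L = jax.random.normal(kx, (nL, d))                   # labeled set L
y_L = jnp.sin(X_L @ jnp.ones(d))                       # toy labels
x_cand, y_cand = jax.random.normal(kc, (d,)), jnp.array(0.7)   # candidate point + hypothesized label
x_test = jax.random.normal(ks, (d,))                   # where we watch the prediction

flat_grad = lambda x: jnp.concatenate(
    [g.ravel() for g in jax.tree_util.tree_leaves(jax.grad(f)(w, x))])

def entk(A, B):                                        # eNTK_w(A, B) for scalar outputs
    return jnp.stack([flat_grad(a) for a in A]) @ jnp.stack([flat_grad(b) for b in B]).T

X_aug = jnp.vstack([X_L, x_cand[None, :]])             # [X_L; x_i]
y_aug = jnp.concatenate([y_L, y_cand[None]])           # [y_L; y_i]
f_aug = jnp.array([f(w, x) for x in X_aug])

K_tx = entk(x_test[None, :], X_aug)                    # 1 x (nL+1)
K_xx = entk(X_aug, X_aug)                              # (nL+1) x (nL+1)
jitter = 1e-6 * jnp.eye(len(X_aug))                    # small ridge for numerical stability
lookahead = f(w, x_test) + (K_tx @ jnp.linalg.solve(K_xx + jitter, y_aug - f_aug))[0]
print(lookahead)   # approximates f_{L ∪ {(x_cand, y_cand)}}(x_test) without any retraining
```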
Rank-one updates for efficient computation: [schematic]

We prove this is exact for infinitely wide networks: $f_0 \to f_{\mathcal L} \to f_{\mathcal L \cup \{(\exi, \yi)\}}$ agrees with the direct $f_0 \to f_{\mathcal L \cup \{(\exi, \yi)\}}$.
The local approximation with the eNTK “should” work much more broadly than the “NTK regime”:
  Much faster than SGD
  Much more effective than the infinite NTK and one-step SGD
  Matches/beats the state of the art
  Downside: usually more computationally expensive (especially memory)

Enables new interaction modes

“Sequential” querying: incorporate true new labels one at a time instead of in a batch; only update the $\eNTK$ occasionally.
Makes sense when labels cost \$ but are fast; other deep AL methods need to retrain.

What can we learn from empirical NTKs?

As a theoretical tool for local understanding:
  Why DPO breaks
  Why GRPO does weird stuff + how to fix
  Fine-grained explanation for early stopping in knowledge distillation
  How you should fine-tune models
As a practical tool for approximating “lookahead” in active learning
Plus: efficiently approximating $\eNTK$s for large output dimensions $k$, with guarantees
Approximating empirical NTKs

I hid something from you on the active learning (and Wei/Hu/Steinhardt fine-tuning) results…
With $k$ classes, $\eNTK(\Xs, \Xs) \in \R^{k N \times k N}$ – potentially very big.
But actually, we know that $\E_{\w} \eNTK_\w(x_1, x_2)$ is diagonal for most architectures. Let
\pNTK_\w(x_1, x_2) =
\underbrace{\left[ \nabla_\w f_1(x_1) \right]}_{1 \times p}
\underbrace{\left[ \nabla_\w f_1(x_2) \right]\tp}_{p \times 1} ;
\qquad \text{then} \quad
\E_\w \eNTK_\w(x_1, x_2) = \E_\w\left[ \pNTK_\w(x_1, x_2) \right] I_k .
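A quick numerical sketch (not from the paper) comparing the full $k \times k$ $\eNTK$ on one pair of inputs to the $\pNTK \cdot I_k$ approximation at two widths; the relative Frobenius error should shrink as the net gets wider. The one-hidden-layer ReLU architecture, initialization scaling, and widths are arbitrary assumptions.

```python
import jax, jax.numpy as jnp

k_out, d = 5, 8                                        # outputs and input dim (arbitrary)

def make_net(width, key):
    k1, k2 = jax.random.split(key)
    w = {"W1": jax.random.normal(k1, (d, width)) / jnp.sqrt(d),
         "W2": jax.random.normal(k2, (width, k_out)) / jnp.sqrt(width)}
    net = lambda w_, x: jnp.maximum(x @ w_["W1"], 0.0) @ w_["W2"]   # one-hidden-layer ReLU net
    return w, net

def jac_flat(net, w, x):                               # (k, p) Jacobian of the outputs w.r.t. params
    J = jax.jacrev(lambda w_: net(w_, x))(w)
    return jnp.concatenate([a.reshape(k_out, -1)
                            for a in jax.tree_util.tree_leaves(J)], axis=1)

key = jax.random.PRNGKey(0)
kx1, kx2, kw = jax.random.split(key, 3)
x1, x2 = jax.random.normal(kx1, (d,)), jax.random.normal(kx2, (d,))

for width in [64, 1024]:
    w, net = make_net(width, kw)
    J1, J2 = jac_flat(net, w, x1), jac_flat(net, w, x2)
    eNTK = J1 @ J2.T                                   # full k x k empirical NTK
    pNTK = J1[0] @ J2[0]                               # scalar kernel from the first logit
    rel_err = jnp.linalg.norm(eNTK - pNTK * jnp.eye(k_out)) / jnp.linalg.norm(eNTK)
    print(width, rel_err)                              # relative Frobenius error; shrinks with width
```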
$\pNTK(\Xs, \Xs) \in \R^{N \times N}$ (no $k$!)
Can also use the “sum of logits” $\frac{1}{\sqrt k} \sum_{j=1}^k f_j$ instead of just the “first logit” $f_1$.
Lots of work (including the above) has used $\pNTK$ instead of $\eNTK$, often without saying anything; sometimes it doesn't seem like they know they're doing it. Can we justify this more rigorously?

pNTK motivation

Say $f(x) = V \phi(x)$, $\phi(x) \in \R^h$, and $V \in \R^{k \times h}$ has rows $v_j \in \R^h$ with iid entries.
If $v_{j,i} \sim \mathcal N(0, \sigma^2)$, then $v_1$ and $\frac{1}{\sqrt k} \sum_{j=1}^k v_j$ have the same distribution.
\begin{align*}
\fragment[1]{
\eNTK_\w(x_1, x_2)_{jj'}
}&\fragment[1]{ {}
= v_j\tp \eNTK^\phi_{\w \setminus V}(x_1, x_2) \, v_{j'} + \indic(j = j') \phi(x_1)\tp \phi(x_2)
}
\\
\fragment[2]{\pNTK_\w(x_1, x_2)}
&\fragment[2]{ {}
= {\color{cb4} v_1\tp} \eNTK^\phi_{\w \setminus V}(x_1, x_2) \, {\color{cb4} v_1} + \phi(x_1)\tp \phi(x_2)
}
\end{align*}
We want to bound the difference $\eNTK(x_1, x_2) - \pNTK(x_1, x_2) I_k$.
We want $v_1\tp A v_1$ and $v_j\tp A v_j$ to be close, and $v_j\tp A v_{j'}$ to be small, for random $v$ and fixed $A$. Using Hanson-Wright:
\displaystyle \frac{\norm{\eNTK - \pNTK I}_F}{\norm{\eNTK}_F} \le \frac{\norm{\eNTK^\phi}_F + 4 \sqrt h}{\Tr(\eNTK^\phi)} k \log \frac{2 k^2}{\delta}
For fully-connected ReLU nets at init., in fan-in mode: the numerator is $\bigO(h \sqrt h)$, the denominator $\Theta(h^2)$.

pNTK's Frobenius error

Same kind of theorem / empirical results for the largest eigenvalue, and empirical results for $\lambda_\min$ and the condition number.

Kernel regression with pNTK

Reshape things to handle prediction appropriately:
\begin{align*}
\underbrace{f_{\eNTK}(\xt)}_{k \times 1} &=
\underbrace{f_0(\xt)}_{k \times 1}
+ \phantom{\Big(}
\underbrace{\eNTK_{\w_0}(\xt, \Xs)}_{k \times k N}
\,\underbrace{\eNTK_{\w_0}(\Xs, \Xs)^{-1}}_{k N \times k N}
\,\underbrace{(\ys - f_0(\Xs))}_{kN \times 1}
\\
\underbrace{f_{\pNTK}(\xt)}_{k \times 1} &=
\underbrace{f_0(\xt)}_{k \times 1}
+ \Big(
\underbrace{\pNTK_{\w_0}(\xt, \Xs)}_{1 \times N}
\,\underbrace{\pNTK_{\w_0}(\Xs, \Xs)^{-1}}_{N \times N}
\,\underbrace{(\ys - f_0(\Xs))}_{N \times k}
\Big)\tp
\end{align*}
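The reshaping above, as a runnable sketch (my toy setup, not the paper's experiments): both predictors are computed from Jacobians at $\w_0$ for a small ReLU net, and they should agree up to the $\bigO(1/\sqrt h)$ error discussed next. Sizes, targets, and the jitter term are illustrative.

```python
import jax, jax.numpy as jnp

k_out, d, width, N = 3, 6, 512, 40                     # all sizes arbitrary
key = jax.random.PRNGKey(1)
k1, k2, kx, ky, kt = jax.random.split(key, 5)
w0 = {"W1": jax.random.normal(k1, (d, width)) / jnp.sqrt(d),
      "W2": jax.random.normal(k2, (width, k_out)) / jnp.sqrt(width)}
net = lambda w, x: jnp.maximum(x @ w["W1"], 0.0) @ w["W2"]

X = jax.random.normal(kx, (N, d))
Y = jax.nn.one_hot(jax.random.randint(ky, (N,), 0, k_out), k_out)   # one-hot regression targets
x_tilde = jax.random.normal(kt, (d,))

def jac(x):                                            # (k, p) Jacobian at w0
    J = jax.jacrev(lambda w_: net(w_, x))(w0)
    return jnp.concatenate([a.reshape(k_out, -1)
                            for a in jax.tree_util.tree_leaves(J)], axis=1)

J_X = jnp.concatenate([jac(x) for x in X], axis=0)     # (kN, p): k rows per training point
J_t = jac(x_tilde)                                     # (k, p)
f0_X = jnp.stack([net(w0, x) for x in X])              # (N, k)
f0_t = net(w0, x_tilde)
jitter = 1e-6                                          # small ridge term for numerical stability

# Full eNTK regression: k x kN and kN x kN kernels, residuals flattened to length kN
resid = (Y - f0_X).reshape(-1)
f_eNTK = f0_t + J_t @ J_X.T @ jnp.linalg.solve(J_X @ J_X.T + jitter * jnp.eye(k_out * N), resid)

# pNTK regression: N x N kernel from the first logit, residuals kept as an N x k matrix
g_X, g_t = J_X[0::k_out], J_t[0]
f_pNTK = f0_t + g_t @ g_X.T @ jnp.linalg.solve(g_X @ g_X.T + jitter * jnp.eye(N), Y - f0_X)

print(f_eNTK)   # the two k-dimensional predictions should be close for wide networks
print(f_pNTK)
```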
We have $\norm{f_{\eNTK}(\xt) - f_{\pNTK}(\xt)} = \bigO(\frac{1}{\sqrt h})$ again.
If we add regularization, we need to “scale” $\lambda$ between the two.

[Plots: kernel regression with pNTK; pNTK speed-up; pNTK speed-up on an active learning task]

pNTK for full CIFAR-10 regression
$\eNTK(\Xs, \Xs)$ on CIFAR-10: 1.8 terabytes of memory.
$\pNTK(\Xs, \Xs)$ on CIFAR-10: 18 gigabytes of memory.
Worse than the infinite NTK for FCN/ConvNet (where they can be computed, if you try hard); way worse than SGD.

Recap

eNTK is a good tool for intuitive understanding of the learning process
[Deng, Ren, M. Li, S., X. Li, Thrampoulidis]
eNTK is practically very effective at “lookahead” for active learning
You should probably use pNTK instead of eNTK for high-dim output problems: