\definecolor{cb1}{RGB}{76,114,176} \definecolor{cb2}{RGB}{221,132,82} \definecolor{cb3}{RGB}{85,168,104} \definecolor{cb4}{RGB}{196,78,82} \definecolor{cb5}{RGB}{129,114,179} \definecolor{cb6}{RGB}{147,120,96} \definecolor{cb7}{RGB}{218,139,195} \newcommand{\abs}[1]{\left\lvert #1 \right\rvert} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator{\bigO}{\mathcal{O}} \DeclareMathOperator{\Cov}{Cov} \DeclareMathOperator{\E}{\mathbb{E}} \newcommand{\indic}{\mathbb{I}} \newcommand{\R}{\mathbb{R}} \DeclareMathOperator{\Softmax}{Softmax} \newcommand{\tp}{^\mathsf{T}} \DeclareMathOperator{\Tr}{Tr} \DeclareMathOperator{\Var}{Var} \DeclareMathOperator{\NTK}{NTK} \DeclareMathOperator{\eNTK}{\mathcal K} \DeclareMathOperator{\pNTK}{pNTK} \newcommand{\lin}{\mathit{lin}} \newcommand{\xt}{{\color{cb2} \tilde x}} \newcommand{\yt}{{\color{cb2} \tilde y}} \newcommand{\exi}{{\color{cb1} x_i}} \newcommand{\yi}{{\color{cb1} y_i}} \newcommand{\yip}{{\color{cb1} y_i^+}} \newcommand{\yim}{{\color{cb1} y_i^-}} \newcommand{\Xs}{{\color{cb1} \mathbf X}} \newcommand{\ys}{{\color{cb1} \mathbf y}} \newcommand{\f}{\mathbf{f}} \newcommand{\w}{\mathbf{w}} \newcommand{\y}{\mathbf{y}} \newcommand{\z}{\mathbf{z}} \newcommand{\st}{\boldsymbol{\chi}} \newcommand{\sti}{{\color{cb1} \chi_i}} \newcommand{\stip}{{\color{cb1} \chi_i^+}} \newcommand{\stim}{{\color{cb1} \chi_i^-}} \newcommand{\stt}{{\color{cb2} \tilde\chi}} \newcommand{\tar}{\mathit{tar}} \newcommand{\ptari}{{\color{cb1} p^\tar_i}} \newcommand{\pstari}{{\color{cb1} p^*_i}}

Local Learning Dynamics
Help Explain (Post-)Training Behaviour

Danica J. Sutherland(she)
University of British Columbia (UBC) / Alberta Machine Intelligence Institute (Amii)
Knowl. dist. analysis
[ICLR-22]
Yi Ren
Shangmin Guo
Finetuning analysis
[ICLR-23]
Yi Ren
Shangmin Guo
Wonho Bae
Active learning
[NeurIPS-22]
M. Amin Mohamadi
Wonho Bae
Pseudo-NTK
[ICML-23]
M. Amin Mohamadi
Wonho Bae
DPO/etc analysis
[ICLR-25]
Yi Ren
GRPO analysis+fix
[arXiv]
Wenlong Deng
Yi Ren
Muchen Li
Xiaoxiao Li
Christos Thrampoulidis

Penn – June 2025

Slides available at djsutherland.ml/slides/entk-penn


Yi (Joshua) Ren

LLM “post-training”

  • Language models (e.g. GPT-2)
    • Scrape up a ton of the internet (usually illegally)
    • Train a big Transformer for next-token prediction
    • Super-useful as component of lots of things…but not necessarily what we want itself
  • Turning a language model into a chatbot (e.g. ChatGPT):
    • Run “supervised fine-tuning” on a dataset of chatbot-like interactions
    • Run “preference optimization”: given prompt x, say A, not B

Surprises in LLM post-training

  • Preference optimization: “given prompt x , say A , not B ”
  • Common algorithm: Direct Preference Optimization [RSM+ NeurIPS-23]
  • Weird things can happen here!
  • Even in the best case, “too much” DPO hurts [RHPF CoLM-24]
  • Makes B way less likely, but eventually, model almost always says some C
  • There are some workarounds, but…why?

Learning dynamics

  • Most theoretical analyses in this area ask: what do optimal solutions look like?
    • Turns out the loss function is very underspecified; there are many optimal solutions
    • “Implicit regularization” studies which one GD/SGD/Adam/… eventually converges to
      • “Eventually” can take a really long time
  • We'll take a related, but more qualitative approach
  • What does each step do to the model?
    • Mapping from parameters to “what the model does” is complicated
    • When I take an SGD step to “learn” f(\exi) \approx \yi ,
      what happens to my predictions on \xt ?
      • Also been called “local elasticity” [HS ICLR-20]

Learning dynamics (Taylor's version)

  • “Learning dynamics” of a model: f_t(\xt) for some fixed \xt as params \w_t change
  • Suppose \z = z(x) , \f = \sigma(\z) , “plain” SGD on \frac1N \sum_{i=1}^N \mathcal L_t(\exi, \yi) :
  • \begin{align*} \underbrace{\f_{t+1}(\xt)}_{k \times 1} - \underbrace{\f_t(\xt)}_{k \times 1} &= \underbrace{\left( \nabla_\w \f(\xt) \rvert_{\w_t} \right)}_{k \times p} \, \underbrace{(\w_{t+1} - \w_t)}_{p \times 1} {\color{gray}{} + \bigO\left( \norm{\w_{t+1} - \w_t}^2 \right)} \\ &= \underbrace{\left( \nabla_\w \f(\xt) \rvert_{\w_t} \right)}_{k \times p} \Bigl( - \eta \underbrace{\nabla_\w \mathcal L(\exi, \yi) \rvert_{\w_t}}_{1 \times p} \Bigr)\tp {\color{gray}{} + \bigO(\eta^2)} \\ &= - \eta \, \underbrace{\bigl( \nabla_\z \f(\xt) \rvert_{\z_t} \bigr)}_{k \times k} \underbrace{\bigl( \nabla_\w \z(\xt) \rvert_{\w_t} \bigr)}_{k \times p} \; \underbrace{ \left( \nabla_\w \z(\exi) \rvert_{\w_t} \right)\tp }_{p \times k} \, \underbrace{ \left( \nabla_\z \mathcal L(\exi, \yi) \rvert_{\z_t} \right)\tp }_{k \times 1} \, {\color{gray}{} + \bigO(\eta^2)} \\ &= - \eta \; \mathcal A_t(\xt) \; \mathcal K_t(\xt, \exi) \; \mathcal G_t(\exi, \yi) {\color{gray}{} + \bigO(\eta^2)} \end{align*}

Learning dynamics

  • “Learning dynamics” of a model: f_t(\xt) for some fixed \xt as params \w_t change
  • To start: \z = h_\theta(x) , f = \sigma(\z) , “plain” SGD on \frac1N \sum_{i=1}^N \mathcal L_t(\exi, \yi) :
  • \f_{t+1}(\xt) - \f_t(\xt) = - \eta \; \mathcal A_t(\xt) \; \mathcal K_t(\xt, \exi) \; \mathcal G_t(\exi, \yi) {\color{gray}{} + \bigO(\eta^2)}

  • \mathcal G_t(\exi, \yi) = \nabla_\z \mathcal L(\exi, \yi) \rvert_{\z_t} : how much do I need to change my \exi prediction?
    • For square loss with \sigma(\z) = \z , \mathcal G_t = \f_t(\exi) - \yi : how wrong was I before?
    • For cross-entropy on logits, \mathcal G_t = \Softmax(\z_t(\exi)) - e_{\yi} : how wrong was I before?
  • \mathcal A_t(\xt) = \nabla_\z \sigma(\z_t) just “converts” prediction changes
    • If \sigma(\z) = \z , \mathcal A_t is the identity; if \sigma = \log \Softmax , \mathcal A_t = \mathbf I_k - \mathbf{1}_k \, \pi_t(\xt)\tp
  • \mathcal K_t(\xt, \exi) = (\nabla_\w \z(\xt)\rvert_{\w_t}) (\nabla_\w \z(\exi)\rvert_{\w_t})\tp is k \times k empirical neural tangent kernel of \z
    • If \exi , \xt are “dissimilar” (small eNTK), stepping on (\exi, \yi) barely changes \xt prediction
    • If \exi , \xt are “similar” (large eNTK), makes \xt prediction more like \yi
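A quick numerical check of this decomposition (a minimal sketch, not the talk's code: it assumes a toy linear model z(x) = Wx with softmax output and cross-entropy loss, so the eNTK of z is just (\xt \cdot \exi)\, I_k):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, eta = 5, 3, 1e-3

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

W = rng.normal(size=(k, d))            # toy "network": logits z(x) = W x
x_i, y_i = rng.normal(size=d), 1       # one training example (x_i, y_i)
x_t = rng.normal(size=d)               # a fixed test point \tilde x

q_i = softmax(W @ x_i)
G = q_i - np.eye(k)[y_i]               # \mathcal G_t: gradient of cross-entropy in the logits
q_t = softmax(W @ x_t)
A = np.diag(q_t) - np.outer(q_t, q_t)  # \mathcal A_t: Jacobian of softmax at \tilde x
K = (x_t @ x_i) * np.eye(k)            # eNTK of z for this linear model: (x_t . x_i) I_k

predicted = -eta * A @ K @ G           # first-order prediction of the change at \tilde x

W_new = W - eta * np.outer(G, x_i)     # one actual SGD step: dL/dW = (q_i - e_{y_i}) x_i^T
actual = softmax(W_new @ x_t) - q_t

print(np.abs(actual - predicted).max())  # tiny: the leftover is the O(eta^2) term
```

The printed discrepancy is on the order of \eta^2, matching the error term above.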

Example: learning dynamics on MNIST

\log \pi_{t+1}(\xt) - \log \pi_t(\xt) \approx - \eta \mathcal A_t(\xt) \, \mathcal K_t(\xt, \exi) \, \mathcal G_t(\exi, \yi)

But wait…aren't NTKs an unrealistic approximation?

A quick aside: the “NTK regime” and infinite limits

  • Full-batch GD: “stacking things up”, \begin{align*} f_{t+1}(\xt) - f_t(\xt) &= - \frac{\eta}{N} \sum_{i=1}^N \mathcal A_t(\xt) \eNTK_{t}(\xt, \exi) \mathcal G_t(\exi, \yi) {\color{gray}{} + \bigO(\eta^2)} \\ &= - \frac{\eta}{N} \underbrace{\mathcal A_t(\xt)}_{k \times k} \; \underbrace{\eNTK_t(\xt, \Xs)}_{k \times k N} \; \underbrace{\mathcal G_t(\Xs, \ys)}_{k N \times 1} {\color{gray}{} + \bigO(\eta^2)} \end{align*}
  • Observation I: If f is “wide enough” with any usual architecture+init* [Yang+Litwin 2021], \eNTK_t(\cdot, \Xs) is roughly constant through training
    • For square loss, \mathcal G_t(\Xs, \ys) = f_t(\Xs) - \ys : dynamics agree with kernel regression!
    • f_t(\xt) \xrightarrow{t \to \infty} \eNTK_{0}(\xt, \Xs) \eNTK_{0}(\Xs, \Xs)^{-1} (\ys - f_0(\Xs)) + f_0(\xt)
  • Observation II: As f becomes “infinitely wide” with any usual architecture+init* [Yang 2019], \eNTK_{0}(x_1, x_2) \xrightarrow{a.s.} \NTK(x_1, x_2) , independent of the random \w_0

Infinite NTKs are great

  • Infinitely-wide neural networks have very simple behaviour!
    • No need to worry about bad local minima, optimization complications, …
    • Understanding “implicit bias” of wide nets \approx understanding NTK norm of functions
  • Can compute \NTK exactly for many architectures
  • A great kernel for many kernel methods!

But (infinite) NTKs aren't “the answer”

  • Computational expense:
    • Poor scaling for large-data problems: typically n^2 memory and n^2 to n^3 computation
      • CIFAR-10 has n = 50\,000 , k = 10 : an nk \times nk matrix of float64s is 2 terabytes!
      • ILSVRC2012 has n \approx 1\,200\,000 , k = 1\,000 : 11.5 million terabytes (exabytes)
    • For deep/complex models (especially CNNs), each pair very slow / memory-intensive
    • Attention is even harder to handle
  • Practical performance:
    • Typically performs worse than GD for “non-small-data” tasks (MNIST and up)
  • Theoretical limitations:
    • NTK “doesn't do feature learning”: we now know many problems where gradient descent on an NN \gg any kernel method
      • Cases where GD error \to 0 , but any kernel method is barely better than random [Malach+ 2021]

What can we learn from empirical NTKs?

  • As a theoretical tool for local understanding:
    • Why DPO breaks
    • Why GRPO does weird stuff + how to fix
    • Fine-grained explanation for early stopping in knowledge distillation
    • How you should fine-tune models
  • As a practical tool for approximating “lookahead” in active learning
  • Plus: efficiently approximating \eNTK s for large output dimensions k , with guarantees

Adapting to the LLM setting

  • First problem: we don't classify a full response at a time, we do it token-by-token
  • Once we've framed it correctly, this is fine: stack prompt+response into \sti = [ \exi, \yi ]
  • Change in \log \pi(\yt_{\color{cb2}m} \mid \xt, \yt_{:{\color{cb2}m}-1}) from a token-by-token update on (\exi, \yi) is [\Delta \log \pi_{t}(\yt \mid \stt)]_{\color{cb2}m} = - \eta \sum_{{\color{cb1}l} = 1}^{\color{cb1} L_i} [\mathcal A_t(\stt)]_{\color{cb2} m} [ \eNTK_t(\stt, \sti) ]_{{\color{cb2}m},\color{cb1}{l}} [ \mathcal G_t(\sti) ]_{\color{cb1}l} {\color{gray}{} + \bigO(\eta^2)}

  • Second problem: we can't check all possible output probabilities anymore
  • Workaround: track some informative possible responses
    • The dataset responses, rephrases, similar strings with different meanings
    • Irrelevant responses in training set, random sentences…

LLM supervised fine-tuning

  • SFT makes dispreferred answers more likely
  • …because they're “similar enough” to the preferred ones
  • \eNTK is reasonably large; \mathcal G starts big (pulls up), gets small (pulls up less)
  • Ungrammatical responses just go down; \eNTK is small, so no upwards pressure
  • Also makes answers to different questions more likely…one form of hallucination?

Direct Preference Optimization (DPO)

\mathcal L_t^{\mathrm{DPO}}(\exi, \yip, \yim) = -\log \sigma\left(\beta\left[ \log \frac{\pi_t(\yip \mid \exi)}{\pi_\mathrm{ref}(\yip \mid \exi)} - \log \frac{\pi_t(\yim \mid \exi)}{\pi_\mathrm{ref}(\yim \mid \exi)} \right]\right)
which gives that [\Delta \log \pi_{t}(\yt \mid \stt)]_{\color{cb2}m} \approx -\eta [\mathcal A_t(\stt)]_{\color{cb2} m} \Biggl( \sum_{{\color{cb1}l} = 1}^{\color{cb1} L_i} [ \eNTK_t(\stt, \stip) ]_{{\color{cb2}m},\color{cb1}{l}} [ \mathcal G_t^\mathrm{DPO}(\stip) ]_{\color{cb1}l} - \sum_{{\color{cb1}l} = 1}^{\color{cb1} L_i} [ \eNTK_t(\stt, \stim) ]_{{\color{cb2}m},\color{cb1}{l}} [ \mathcal G_t^\mathrm{DPO}(\stim) ]_{\color{cb1}l} \Biggr), \qquad \text{where } \mathcal G_t^\mathrm{DPO}(\st) = \beta \bigl(1 - \sigma( \dots )\bigr) \bigl( \pi_t(\y \mid \st) - e_{\y} \bigr)
This negative gradient can do really weird things:

Negative gradients and the squeezing effect

\pi(y \mid x) = \frac{\exp(z(x)_y)}{\exp(z(x)_y) + \exp(z(x)_{y^*}) + \dots}

  • To decrease \log \pi((\yim)_m \mid [\stim]_{:m}) , decrease numerator and increase denominator
  • If z(x)_{y^*} is big, dominates the sum: increasing it is almost as effective as decreasing z(x)_{y}
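A tiny numerical illustration of the squeezing effect (a sketch; the five-way “vocabulary”, logits, and step size are all made up): repeatedly stepping to decrease \log \pi(y^- \mid x) on the logits mainly inflates the already-largest logit, so probability mass concentrates on y^* and the other plausible answers also lose mass.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Five possible "answers"; index 0 plays y^- (dispreferred), index 1 is the
# currently most-likely answer y^*.  All numbers here are made up.
z = np.array([2.0, 4.0, 1.0, 0.5, 0.0])
print("before:", softmax(z))
eta = 0.5
for _ in range(50):
    pi = softmax(z)
    # gradient step that decreases log pi(y^-):  d log pi_0 / dz = e_0 - pi
    z -= eta * (np.eye(5)[0] - pi)
print("after: ", softmax(z))   # y^- went down, but so did answers 2-4: mass piled onto y^*
print("logits:", z)            # z_{y^*} grew the most, as in the argument above
```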

Positive gradients cancel out…in the positive context

Squeezing effect accumulates over time

What can we learn from empirical NTKs?

  • As a theoretical tool for local understanding:
    • Why DPO breaks
    • Why GRPO does weird stuff + how to fix
    • Fine-grained explanation for early stopping in knowledge distillation
    • How you should fine-tune models
  • As a practical tool for approximating “lookahead” in active learning
  • Plus: efficiently approximating \eNTK s for large output dimensions k , with guarantees

Group Relative Policy Optimization (GRPO) [DeepSeekMath 24]

  • Similar to a “group-wise” version of DPO; negative gradients have similar effect!

Negative token hidden rewards

  • Estimate which tokens are bad by correlation to tokens in positive responses

Down-weight penalties on tokens that are probably okay

What can we learn from empirical NTKs?

  • As a theoretical tool for local understanding:
    • Why DPO breaks
    • Why GRPO does weird stuff + how to fix
    • Fine-grained explanation for early stopping in knowledge distillation
    • How you should fine-tune models
  • As a practical tool for approximating “lookahead” in active learning
  • Plus: efficiently approximating \eNTK s for large output dimensions k , with guarantees

Better supervisory signal implies better learning

  • Classification: target is \mathcal L_P = \E_{(x, y)} \mathcal L(x, y) = \E_x \E_{y \mid x} \ell_y(f(x))
  • Normally: see \{(\exi, \yi)\} , minimize ( \vec\ell(\hat y) \in \R^k is the vector of losses for all possible labels) \mathcal L_{\Xs, \ys} = \frac1N \sum_{i=1}^N \ell_{\yi}(f(\exi)) = \frac1N \sum_{i=1}^N \sum_{c=1}^k \indic(\yi = c) \ell_c(f(\exi)) = \frac1N \sum_{i=1}^N e_{\yi} \cdot \vec\ell(f(\exi))
  • Potentially better scheme: see \{(\exi, \ptari)\} , minimize \displaystyle L^\tar(f) = \frac1N \sum_{i=1}^N \ptari \cdot \vec\ell(f(\exi))
    • Can reduce variance if \ptari \approx \pstari , the true conditional probabilities
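A small Monte Carlo sketch of this variance-reduction point (assuming, for illustration, a single fixed x with made-up \pstari and logits): the soft-target loss with \ptari = \pstari has the same expectation as the one-hot loss with sampled \yi, but zero variance.

```python
import numpy as np

rng = np.random.default_rng(0)
k = 3
p_star = np.array([0.7, 0.2, 0.1])    # made-up true conditional P(y | x) at one fixed x
f_x = np.array([1.2, 0.3, -0.5])      # made-up current logits f(x)
log_q = f_x - np.log(np.exp(f_x).sum())
ell = -log_q                          # vector of cross-entropy losses for each possible label

hard = ell[rng.choice(k, size=100_000, p=p_star)]   # one-hot losses with y ~ p*
soft = p_star @ ell                                 # soft-target loss with p^tar = p*

print(hard.mean(), hard.var())   # same mean as `soft`, but noticeable variance across draws
print(soft)                      # deterministic: the sampling noise in y is integrated out
```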

Knowledge distillation

  • Process:
    • Train a teacher f^\mathit{teacher} on \{ (\exi, \yi) \} with standard ERM, L(f)
    • Train a student on \{ (\exi, {\color{cb4} f^\mathit{teacher}}(\exi)) \} with L^\tar
  • Usually \color{cb5} f^\mathit{student} is “smaller” than \color{cb4} f^\mathit{teacher}
  • But “self-distillation” (using the same architecture), often \color{cb5} f^\mathit{student} outperforms \color{cb4} f^\mathit{teacher} !
  • One possible explanation: \color{cb4} f^\mathit{teacher}(\exi) is closer to \pstari than sampled \yi
  • But why would that be?

Zig-Zagging behaviour in learning

Plots of (three-way) probabilistic predictions: one colour of \boldsymbol\times marks \pstari , the other marks \yi

eNTK explains it

  • Let q_t(\xt) = \operatorname{softmax}(f_t(\xt)) \in \R^k ; for cross-entropy loss, one SGD step gives us q_{t+1}(\xt) - q_t(\xt) = \eta \; \mathcal A_t(\xt) \, \eNTK_{\w_t}(\xt, \exi) \, (\ptari - q_t(\exi)) {\color{gray}{} + \bigO(\eta^2)} , where \mathcal A_t(\xt) = \operatorname{diag}(q_t(\xt)) - q_t(\xt) q_t(\xt)\tp is the covariance of a \operatorname{Categorical}(q_t(\xt)) distribution
  • Improves distillation (esp. with noisy labels) to take moving average of q_t(\exi) as \ptari
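A toy sketch of the moving-average target idea (the Dirichlet stand-in for q_t(\exi) and the 0.9 momentum are illustrative assumptions, not the paper's exact recipe): averaging the model's own predictions can recover something close to \pstari even when the recorded label is wrong.

```python
import numpy as np

rng = np.random.default_rng(0)

def ema_target(p_tar, q, momentum=0.9):
    """Blend the stored target with the model's current prediction q_t(x_i)."""
    return momentum * p_tar + (1.0 - momentum) * q

p_star = np.array([0.6, 0.3, 0.1])   # made-up true conditional for one example
p_tar = np.eye(3)[2]                 # start from a *wrong* one-hot label (label noise)
for _ in range(200):
    q_t = rng.dirichlet(50 * p_star) # stand-in for q_t(x_i) fluctuating around p*
    p_tar = ema_target(p_tar, q_t)
print(p_tar)                         # ends up close to p*, unlike the noisy hard label
```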

What can we learn from empirical NTKs?

  • As a theoretical tool for local understanding:
    • Why DPO breaks
    • Why GRPO does weird stuff + how to fix
    • Fine-grained explanation for early stopping in knowledge distillation
    • How you should fine-tune models
  • As a practical tool for approximating “lookahead” in active learning
  • Plus: efficiently approximating \eNTK s for large output dimensions k , with guarantees

Fine-tuning

  • Pretrain, re-initialize a random head, then adapt to a downstream task. Two phases:
    • Head probing: only update the head g(z)
    • Fine-tuning: update head g(z) and backbone z = f(x) together
  • If we only fine-tune: noise from random head might break our features!
  • If we head-probe to convergence: might already fit training data and not change features!

How much do we change our features?

  • Same kind of decomposition with backbone features z = f(x) , head q = \operatorname{softmax}(g(z)) : z_{t+1}(\xt) - z_t(\xt) = \frac{\eta}{N} \sum_{i=1}^N \underbrace{\eNTK_{\w_t}^{f}(\xt, \exi)}_\text{eNTK of backbone} \, \underbrace{\left(\nabla_z q_t(\exi)\right)\tp}_\text{direction of head} \, \underbrace{(e_{\yi} - q_t(\exi))}_\text{“energy”} {\color{gray}{} + \bigO(\eta^2)}
  • If initial “energy”, e.g. \E_{\exi,\yi} \norm{e_{\yi} - q_0(\exi)} , is small, features don't change much
  • If we didn't do any head probing, “direction” is very random, especially if g is rich
  • Specializing to simple linear-linear model, can get insights about trends in z
  • Recommendations from paper:
    • Early stop during head probing (ideally, try multiple lengths for downstream task)
    • Label smoothing can help; so can more complex heads, but be careful

How good will our fine-tuned features be? [Wei/Hu/Steinhardt 2022]

  • With random head (no head probing),
    generalized cross-validation on eNTK model gives excellent estimate of downstream loss

What can we learn from empirical NTKs?

  • As a theoretical tool for local understanding:
    • Why DPO breaks
    • Why GRPO does weird stuff + how to fix
    • Fine-grained explanation for early stopping in knowledge distillation
    • How you should fine-tune models
  • As a practical tool for approximating “lookahead” in active learning
  • Plus: efficiently approximating \eNTK s for large output dimensions k , with guarantees

Pool-based active learning

  • Pool of unlabeled data available; ask for annotations of the “most informative” points
  • Kinds of acquisition functions used for deep active learning:
    • Uncertainty-based: maximum entropy, BALD
    • Representation-based: BADGE, LL4AL
  • Another kind used for simpler models: lookahead criteria
    • “How much would my model change if I saw \exi with label \yi ?”
    • Too expensive for deep learning…unless you use a local approximation to retraining

Approximate retraining with local linearization

  • Given f_{\mathcal L} trained on labeled data \mathcal L , approximate f_{\mathcal L \cup \{(\exi, \yi)\}} with local linearization \!\!\!\!\! f_{\mathcal L \cup \{(\exi, \yi)\}}(\xt) \approx f_{\mathcal L}(\xt) + \eNTK_{\w_{\mathcal L}}\left(\xt, \begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right) \eNTK_{\w_{\mathcal L}}\left(\begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix}, \begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right)^{-1} \left( \begin{bmatrix} \ys_{\mathcal L} \\ \yi \end{bmatrix} - f_{\mathcal L}\left(\begin{bmatrix} \Xs_{\mathcal L} \\ \exi \end{bmatrix} \right) \right)
    • Rank-one updates of the kernel inverse for efficient computation
  • We prove this is exact for infinitely wide networks
    • f_0 \to f_{\mathcal L} \to f_{\mathcal L \cup \{(\exi, \yi)\}} agrees with direct f_0 \to f_{\mathcal L \cup \{(\exi, \yi)\}}
  • Local approximation with eNTK “should” work much more broadly than “NTK regime”
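A simplified sketch of this lookahead computation (assumptions: f_0 \equiv 0, a small ridge term for stability, and a linear model's eNTK X_1 X_2\tp standing in for a real network's Jacobian inner products):

```python
import numpy as np

rng = np.random.default_rng(0)

def entk(X1, X2):
    # Stand-in kernel: for a linear model f(x) = w.x the eNTK is exactly X1 X2^T.
    # For a real network, these entries would be inner products of parameter gradients.
    return X1 @ X2.T

n, d = 20, 5
X_lab = rng.normal(size=(n, d)); y_lab = rng.normal(size=n)   # current labeled set
x_cand = rng.normal(size=(1, d)); y_guess = np.array([0.7])   # candidate + hypothesised label
x_test = rng.normal(size=(1, d))                              # where we watch the prediction

def kernel_predict(X, y, x_new, ridge=1e-3):
    K = entk(X, X) + ridge * np.eye(len(X))
    return entk(x_new, X) @ np.linalg.solve(K, y)

f_before = kernel_predict(X_lab, y_lab, x_test)
X_aug = np.vstack([X_lab, x_cand]); y_aug = np.concatenate([y_lab, y_guess])
f_after = kernel_predict(X_aug, y_aug, x_test)            # "retrained" model, no SGD needed
print(f_before, f_after, np.abs(f_after - f_before))      # feeds into an acquisition score
```

In practice one keeps the kernel matrix factorization and applies block / rank-one inverse updates when appending a candidate, rather than re-solving from scratch as above.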

Much faster than SGD

Much more effective than infinite NTK and one-step SGD

Matches/beats state of the art

Downside: usually more computationally expensive (especially memory)

Enables new interaction modes

  • “Sequential” querying: incorporate true new labels one-at-a-time instead of batch
    • Only update \eNTK occasionally
    • Makes sense when labels cost $ but are fast; other deep AL methods need to retrain

What can we learn from empirical NTKs?

  • As a theoretical tool for local understanding:
    • Why DPO breaks
    • Why GRPO does weird stuff + how to fix
    • Fine-grained explanation for early stopping in knowledge distillation
    • How you should fine-tune models
  • As a practical tool for approximating “lookahead” in active learning
  • Plus: efficiently approximating \eNTK s for large output dimensions k , with guarantees

Approximating empirical NTKs

  • I hid something from you on active learning (and Wei/Hu/Steinhardt fine-tuning) results…
  • With k classes, \eNTK(\Xs, \Xs) \in \R^{k N \times k N} – potentially very big
  • But actually, we know that \E_{\w} \eNTK_\w(x_1, x_2) is diagonal for most architectures
    • Let \displaystyle \pNTK_\w(x_1, x_2) = \underbrace{\left[ \nabla_\w f_1(x_1) \right]}_{1 \times p} \underbrace{\left[ \nabla_\w f_1(x_2) \right]\tp}_{p \times 1} ; then \E_\w \eNTK_\w(x_1, x_2) = \E_\w\left[ \pNTK_\w(x_1, x_2) \right] I_k , and \pNTK(\Xs, \Xs) \in \R^{N \times N} (no k !)
    • Can also use the “sum of logits” \frac{1}{\sqrt k} \sum_{j=1}^k f_j instead of just the “first logit” f_1
  • Lots of work (including above) has used \pNTK instead of \eNTK
    • Often without saying anything; sometimes doesn't seem like they know they're doing it
  • Can we justify this more rigorously?
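For concreteness, a minimal sketch of computing the pNTK with autograd (the tiny MLP and sizes are placeholders; a practical version would batch and parallelize the Jacobian computation):

```python
import torch

torch.manual_seed(0)
d, h, k = 4, 16, 3
net = torch.nn.Sequential(torch.nn.Linear(d, h), torch.nn.ReLU(), torch.nn.Linear(h, k))
params = list(net.parameters())

def grad_first_logit(x):
    """Flattened gradient of f_1(x) (just the first logit) w.r.t. all parameters."""
    grads = torch.autograd.grad(net(x)[0], params)
    return torch.cat([g.reshape(-1) for g in grads])

def pntk(x1, x2):
    # one scalar per pair of inputs, instead of a k x k eNTK block
    return grad_first_logit(x1) @ grad_first_logit(x2)

x1, x2 = torch.randn(d), torch.randn(d)
print(pntk(x1, x2).item())
```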

pNTK motivation

  • Say f(x) = V \phi(x) , \phi(x) \in \R^h , and V \in \R^{k \times h} has rows v_j \in \R^h with iid entries
  • If v_{j,i} \sim \mathcal N(0, \sigma^2) , then v_1 and \frac{1}{\sqrt k} \sum_{j=1}^k v_j have the same distribution \begin{align*} \eNTK_\w(x_1, x_2)_{jj'} &= v_j\tp \eNTK^\phi_{\w \setminus V}(x_1, x_2) \, v_{j'} + \indic(j = j') \phi(x_1)\tp \phi(x_2) \\ \pNTK_\w(x_1, x_2) &= {\color{cb4} v_1\tp} \eNTK^\phi_{\w \setminus V}(x_1, x_2) \, {\color{cb4} v_1} + \phi(x_1)\tp \phi(x_2) \end{align*}
  • We want to bound difference \eNTK(x_1, x_2) - \pNTK(x_1, x_2) I_k
    • Want v_1\tp A v_1 and v_j\tp A v_j to be close, and v_j\tp A v_{j'} small, for random v and fixed A
    • Using Hanson-Wright: \displaystyle \frac{\norm{\eNTK - \pNTK I}_F}{\norm{\eNTK}_F} \le \frac{\norm{\eNTK^\phi}_F + 4 \sqrt h}{\Tr(\eNTK^\phi)} k \log \frac{2 k^2}{\delta}
    • Fully-connected ReLU nets at init., fan-in mode: numerator \bigO(h \sqrt h) , denom \Theta(h^2)
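A quick numerical check of this structure (a sketch with stand-ins: a random PSD matrix plays the role of \eNTK^\phi, random features play \phi, and the last layer uses fan-in 1/\sqrt h scaling): the relative Frobenius gap between the k \times k eNTK block and \pNTK \cdot I_k shrinks as the width h grows.

```python
import numpy as np

rng = np.random.default_rng(0)
k = 10

def rel_frob_gap(h):
    B = rng.normal(size=(h, h)) / np.sqrt(h)
    K_phi = B @ B.T                                   # stand-in for eNTK^phi(x, x), h x h PSD
    phi = rng.normal(size=h)                          # stand-in features phi(x)
    V = rng.normal(size=(k, h)) / np.sqrt(h)          # fan-in-scaled head, rows v_j
    eNTK = V @ K_phi @ V.T + (phi @ phi) * np.eye(k)  # the k x k block from above
    pNTK = V[0] @ K_phi @ V[0] + phi @ phi            # the matching scalar pNTK
    return np.linalg.norm(eNTK - pNTK * np.eye(k)) / np.linalg.norm(eNTK)

for h in [64, 256, 1024, 4096]:
    print(h, rel_frob_gap(h))                         # gap shrinks as the width h grows
```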

pNTK's Frobenius error

Same kind of theorem / empirical results for largest eigenvalue,
and empirical results for \lambda_\min , condition number

Kernel regression with pNTK

  • Reshape things to handle prediction appropriately: \begin{align*} \underbrace{f_{\eNTK}(\xt)}_{k \times 1} &= \underbrace{f_0(\xt)}_{k \times 1} + \phantom{\Big(} \underbrace{\eNTK_{\w_0}(\xt, \Xs)}_{k \times k N} \,\underbrace{\eNTK_{\w_0}(\Xs, \Xs)^{-1}}_{k N \times k N} \,\underbrace{(\ys - f_0(\Xs))}_{kN \times 1} \\ \underbrace{f_{\pNTK}(\xt)}_{k \times 1} &= \underbrace{f_0(\xt)}_{k \times 1} + \Big( \underbrace{\pNTK_{\w_0}(\xt, \Xs)}_{1 \times N} \,\underbrace{\pNTK_{\w_0}(\Xs, \Xs)^{-1}}_{N \times N} \,\underbrace{(\ys - f_0(\Xs))}_{N \times k} \Big)\tp \end{align*}
  • We have \norm{f_{\eNTK}(\xt) - f_{\pNTK}(\xt)} = \bigO(\frac{1}{\sqrt h}) again
    • If we add regularization, need to “scale” \lambda between the two

Kernel regression with pNTK

pNTK speed-up

pNTK speed-up on active learning task

pNTK for full CIFAR-10 regression

  • \eNTK(\Xs, \Xs) on CIFAR-10: 1.8 terabytes of memory
  • \pNTK(\Xs, \Xs) on CIFAR-10: 18 gigabytes of memory
  • Worse than infinite NTK for FCN/ConvNet (where they can be computed, if you try hard)
  • Way worse than SGD

Recap

eNTK is a good tool for intuitive understanding of the learning process

eNTK is practically very effective at “lookahead” for active learning

You should probably use pNTK instead of eNTK for high-dimensional output problems.