Understanding neural network optimizers through matrix preconditioning
how and why we build algorithms like AdamW and Muon
quantifying error in systems of linear equations
A typical course on linear algebra teaches us that a core operation is solving Ax=b, that is, finding the vector x that the linear transformation A maps to the vector b.
When solving this problem with a computer, our floating point numbers are inexact. We’ve carefully built floating point systems to minimize the truncation and rounding error that comes from representing a real number with a limited number of bits. For example, the bfloat16 format uses 16 bits but keeps the same 8 exponent bits as float32, giving up mantissa precision in exchange for float32’s dynamic range. But even carefully designed systems accumulate error. A natural question to ask is how much the imprecision of our floating point numbers affects the solution to the system Ax=b. If we imagine a small perturbation δb of b, we can instead consider the solution to A(x + δx) = b + δb. I won’t derive it here, but using the submultiplicative property of any induced (operator) p-norm, we have that

$$\frac{\|\delta x\|}{\|x\|} \le k(A)\,\frac{\|\delta b\|}{\|b\|}, \qquad k(A) = \|A\|\,\|A^{-1}\|.$$
k(A) is called the condition number of A: it bounds how much the relative error in b can be amplified in the relative error of the solution x.
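To make this concrete, here is a minimal pure-Python sketch, assuming 2×2 matrices only so no NumPy is needed (the helpers `cond2` and `solve2` are hypothetical names). It perturbs b for a nearly singular matrix and compares the observed amplification, (‖δx‖/‖x‖) / (‖δb‖/‖b‖), with the 2-norm condition number σ_max/σ_min, computed via the 2×2 identity σ_max·σ_min = |det A|.

```python
import math


def cond2(A):
    """2-norm condition number sigma_max / sigma_min of a 2x2 matrix A."""
    (a, b), (c, d) = A
    det = a * d - b * c
    if det == 0:
        return math.inf  # singular matrix
    # sigma_max^2 is the largest eigenvalue of the symmetric matrix A^T A.
    p, q, r = a * a + c * c, a * b + c * d, b * b + d * d
    sigma_max_sq = (p + r) / 2 + math.hypot((p - r) / 2, q)
    # For 2x2 matrices sigma_max * sigma_min = |det A|, so the ratio is:
    return sigma_max_sq / abs(det)


def solve2(A, rhs):
    """Solve the 2x2 system A x = rhs with Cramer's rule."""
    (a, b), (c, d) = A
    det = a * d - b * c
    return [(rhs[0] * d - b * rhs[1]) / det, (a * rhs[1] - c * rhs[0]) / det]


A = [[1.0, 1.0], [1.0, 1.0001]]  # nearly singular, hence ill-conditioned
x_true = [1.0, 1.0]
b = [A[0][0] * x_true[0] + A[0][1] * x_true[1],
     A[1][0] * x_true[0] + A[1][1] * x_true[1]]

db = [0.0, 1e-4]  # small perturbation of the right-hand side
x_pert = solve2(A, [b[0] + db[0], b[1] + db[1]])
dx = [x_pert[0] - x_true[0], x_pert[1] - x_true[1]]

rel_err_x = math.hypot(*dx) / math.hypot(*x_true)
rel_err_b = math.hypot(*db) / math.hypot(*b)
amplification = rel_err_x / rel_err_b
print(f"k(A) = {cond2(A):.0f}, observed amplification = {amplification:.0f}")
```

For this matrix the observed amplification is in the tens of thousands, below k(A) ≈ 4 × 10⁴; choosing δb along the least-stretched direction of A would push it up to the bound.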

When k(A) = 1, the ideal case, the relative error in the solution is bounded by the relative error in the input. If k(A) is large, for example 100, the relative error in the solution could be up to 100 times (two orders of magnitude) worse than the relative error in the input. For this reason, we generally try to avoid solving systems with poor conditioning. If you’ve ever played around with an algorithm that does any sort of matrix decomposition, you may have seen error messages like “matrix X is ill-conditioned, could not converge”. Poor conditioning hurts in two ways: rounding errors are amplified by up to k(A), and iterative solvers need more iterations as k(A) grows (roughly proportional to k(A) for gradient descent and to √k(A) for conjugate gradient on symmetric positive definite systems).
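It turns out that for the matrix 2-norm we have a nice expression in terms of the singular values σ of a general matrix, and in terms of the eigenvalues λ if the matrix is normal (in particular, symmetric):

$$k_2(A) = \frac{\sigma_{\max}(A)}{\sigma_{\min}(A)}, \qquad k_2(A) = \frac{|\lambda|_{\max}}{|\lambda|_{\min}} \ \text{ if } A \text{ is normal}.$$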
Intuitively, the singular values tell us how much our transformation A stretches space along orthogonal directions. So we can think of the condition number k(A) as the ratio between the most stretched and least stretched directions when A is full rank (if A is singular, σ_min = 0 and k(A) is infinite).
This is a very useful identity since it allows us to understand error in our system by analyzing the spectral properties of A.
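The distinction between normal and non-normal matrices matters. This sketch, again restricted to 2×2 matrices with hypothetical helper names, computes the 2-norm condition number σ_max/σ_min and the eigenvalue ratio |λ|_max/|λ|_min for a symmetric matrix, where they agree, and for a non-normal shear, where the eigenvalues are both 1 but the matrix stretches space very unevenly.

```python
import math


def cond_singular(A):
    """2-norm condition number sigma_max / sigma_min of a 2x2 matrix."""
    (a, b), (c, d) = A
    det = a * d - b * c
    if det == 0:
        return math.inf
    # sigma_max^2 is the largest eigenvalue of the symmetric matrix A^T A ...
    p, q, r = a * a + c * c, a * b + c * d, b * b + d * d
    sigma_max_sq = (p + r) / 2 + math.hypot((p - r) / 2, q)
    # ... and for a 2x2 matrix sigma_max * sigma_min = |det A|.
    return sigma_max_sq / abs(det)


def cond_eigen(A):
    """|lambda|_max / |lambda|_min for a 2x2 matrix."""
    (a, b), (c, d) = A
    tr, det = a + d, a * d - b * c
    if det == 0:
        return math.inf
    disc = tr * tr / 4 - det
    if disc < 0:
        return 1.0  # complex-conjugate eigenvalues have equal magnitude
    lam_max = abs(tr) / 2 + math.sqrt(disc)  # largest |lambda|
    return lam_max * lam_max / abs(det)      # |lambda_1 * lambda_2| = |det A|


S = [[4.0, 1.0], [1.0, 2.0]]    # symmetric, hence normal
N = [[1.0, 100.0], [0.0, 1.0]]  # non-normal shear; both eigenvalues are 1

print(cond_singular(S), cond_eigen(S))  # equal for a normal matrix
print(cond_singular(N), cond_eigen(N))  # very different for a non-normal one
```

For the shear the eigenvalue ratio is exactly 1 while σ_max/σ_min is about 10⁴, so eigenvalues alone can badly understate the conditioning of a non-normal matrix.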
checkpoint 1
The condition number k(A) bounds the error amplification we can expect in the solution to Ax=b when b is perturbed. In the 2-norm it equals the ratio of the largest to smallest singular values, which for normal matrices is the ratio of the largest to smallest eigenvalue magnitudes.
preconditioning in systems of equations
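We want to avoid solving systems with very high condition numbers, since we expect a lot of error in the solution due to the irreducible error in our floating point numbers. To do so, we can solve a different but equivalent system of equations: take Ax=b and multiply by one or more invertible matrices, so that the solution is unchanged but the system is better conditioned:

$$\underbrace{MAx = Mb}_{\text{left}}, \qquad \underbrace{AMy = b,\; x = My}_{\text{right}}, \qquad \underbrace{M_1 A M_2\, y = M_1 b,\; x = M_2 y}_{\text{split}}.$$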
I’ll refer to left-preconditioning for the rest of this article since it naturally arises in gradient descent. If M improves the conditioning, i.e. k(MA) ≪ k(A), an iterative solver will converge faster and the solution will be less sensitive to perturbations. We call this left-multiplication by M “preconditioning”, and it has been an active area of applied mathematics for many decades. The ideal preconditioner is M = A⁻¹, which gives MA = I and k(MA) = 1, but computing it is as hard as the original problem, so in practice we look for an M that approximates A⁻¹ and is cheap to apply. With a good preconditioner we can afford to run our solver at lower precision, since truncation and rounding errors are amplified much less. In practice it’s common to never explicitly construct M, but rather to apply it implicitly inside iterative solvers. Remember that under the 2-norm the condition number of a normal matrix is determined by its eigenvalues. Keep this in mind, since we’ll have to consider the spectral properties of the preconditioned matrix MA as well.
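A small sketch of left preconditioning, MAx = Mb, with the simplest common choice, a Jacobi (diagonal) preconditioner M = diag(A)⁻¹. It runs a left-preconditioned Richardson iteration x ← x + ωM(b − Ax) on a badly scaled 2×2 system, once with M = I and once with Jacobi, picking the step ω = 2/(λ_max + λ_min) from the eigenvalues. The matrix, tolerance, and helper names are illustrative assumptions; because MA is not symmetric, the reported "k" for the Jacobi case is the eigenvalue ratio of MA (equal to that of D^{-1/2}AD^{-1/2}), not its 2-norm condition number.

```python
import math


def sym_eigvals2(S):
    """Eigenvalues (largest, smallest) of a symmetric 2x2 matrix."""
    a, b, c = S[0][0], S[0][1], S[1][1]
    mid, rad = (a + c) / 2, math.hypot((a - c) / 2, b)
    return mid + rad, mid - rad


def richardson(A, b, M_diag, omega, tol=1e-8, max_iter=10_000):
    """Left-preconditioned Richardson iteration x <- x + omega * M (b - A x).

    M is diagonal and passed as its diagonal entries. Returns the number of
    iterations until ||b - A x|| < tol * ||b||.
    """
    x = [0.0, 0.0]
    b_norm = math.hypot(*b)
    for k in range(max_iter):
        r = [b[i] - (A[i][0] * x[0] + A[i][1] * x[1]) for i in range(2)]
        if math.hypot(*r) < tol * b_norm:
            return k
        x = [x[i] + omega * M_diag[i] * r[i] for i in range(2)]
    return max_iter


A = [[100.0, 1.0], [1.0, 1.0]]  # symmetric positive definite, badly scaled
b = [1.0, 1.0]

# No preconditioner (M = I), with the optimal step 2 / (lambda_max + lambda_min).
lmax, lmin = sym_eigvals2(A)
k_plain = lmax / lmin
iters_plain = richardson(A, b, [1.0, 1.0], 2 / (lmax + lmin))

# Jacobi preconditioner M = diag(A)^{-1}. MA has the same eigenvalues as the
# symmetric scaling D^{-1/2} A D^{-1/2}, used here to pick omega and report k.
d = [A[0][0], A[1][1]]
scaled = [[A[i][j] / math.sqrt(d[i] * d[j]) for j in range(2)] for i in range(2)]
pmax, pmin = sym_eigvals2(scaled)
k_jacobi = pmax / pmin
iters_jacobi = richardson(A, b, [1 / d[0], 1 / d[1]], 2 / (pmax + pmin))

print(f"k = {k_plain:.1f}: {iters_plain} iterations without preconditioning")
print(f"k = {k_jacobi:.2f}: {iters_jacobi} iterations with Jacobi preconditioning")
```

Rescaling the rows drops the eigenvalue ratio from about 100 to about 1.2, and the iteration count falls from the hundreds to around ten. Note that M is applied as an elementwise scaling and never formed as a matrix.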
back to gradient descent
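Let’s take a step back and think about why gradient-based optimization for neural networks exists at all. Given the parameters θ of a neural network f, we try to find the parameters that minimize the average per-example error ℓ across our data 𝒟:

$$\theta^{*} = \arg\min_{\theta} L(\theta), \qquad L(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\big[\ell(f(x;\theta), y)\big].$$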
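To find this minimum, we could try to directly solve ∇L = 0, but there is no closed-form solution since ∇L of a neural network is nonlinear in the parameters θᵢ. Instead, we start from a random initialization θ₀ and use the gradient, which locally gives the direction of steepest ascent, to iteratively decrease our loss by stepping in the opposite direction. This gives rise to stochastic gradient descent, which averages the gradient of the per-example error ℓ over a batch 𝓑:

$$\theta_{t+1} = \theta_t - \alpha\, \hat{\mathbb{E}}_{(x,y)\in\mathcal{B}}\big[\nabla_\theta\, \ell(f(x;\theta_t), y)\big],$$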
where we write the empirical expectation since we compute this over a subset of our data (a batch). This requires choosing a step size α, telling us how far along the loss surface to move given only a locally linear approximation around the point θₜ.
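The update above can be sketched in a few lines of pure Python. This is a toy, assuming a linear model w·x + c in place of a neural network, synthetic data y = 3x − 2 plus noise, squared error as ℓ, batch size 32 and α = 0.1; all names are illustrative.

```python
import random

random.seed(0)

# Synthetic data (an assumption for illustration): y = 3x - 2 plus small noise.
xs = [random.uniform(-1.0, 1.0) for _ in range(256)]
data = [(x, 3.0 * x - 2.0 + random.gauss(0.0, 0.1)) for x in xs]


def loss(theta, batch):
    """Mean squared error of the linear model w * x + c over a batch."""
    w, c = theta
    return sum((w * x + c - y) ** 2 for x, y in batch) / len(batch)


def batch_grad(theta, batch):
    """Empirical expectation of the per-example gradient over the batch."""
    w, c = theta
    gw = sum(2 * (w * x + c - y) * x for x, y in batch) / len(batch)
    gc = sum(2 * (w * x + c - y) for x, y in batch) / len(batch)
    return gw, gc


theta = (0.0, 0.0)  # (w, c); a fixed start instead of a random one
alpha = 0.1         # step size
for t in range(500):
    batch = random.sample(data, 32)  # a random minibatch
    gw, gc = batch_grad(theta, batch)
    theta = (theta[0] - alpha * gw, theta[1] - alpha * gc)

print("theta =", theta, "full-data loss =", loss(theta, data))
```

Each step only sees 32 of the 256 points, yet the parameters settle close to (3, −2) because the batch gradient is an unbiased estimate of the full gradient.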
The issue with the gradient being a linear approximation around θₜ is that it only tells us the direction of steepest descent, and nothing about the curvature, i.e. how far we can travel along our loss surface in each direction before the approximation breaks down. The Hessian H, the matrix of second partial derivatives of our loss L with entries H_ij = ∂²L/∂θᵢ∂θⱼ, tells us the local curvature around our current parameters θₜ.
This is exactly what gives rise to “second-order” optimizers: they use the Hessian to scale the step size in each direction along the loss surface, reducing how much we “overshoot” our linear approximation in directions that require a smaller step.
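Why a single step size is a problem can be seen on a toy quadratic with made-up curvatures 100 (steep) and 1 (flat). Gradient descent with a shared α must keep α below 2/λ_max to avoid diverging along the steep direction, which leaves the flat direction crawling. A minimal sketch, assuming an axis-aligned quadratic so the Hessian is diagonal:

```python
# Eigenvalues of the Hessian of a toy, axis-aligned quadratic (made-up values):
# one steep direction and one flat direction.
curvatures = (100.0, 1.0)


def gd(alpha, steps=100, theta0=(1.0, 1.0)):
    """Gradient descent on L(theta) = 0.5 * sum_i lam_i * theta_i**2."""
    theta = list(theta0)
    for _ in range(steps):
        # The gradient of this quadratic is lam_i * theta_i in each coordinate.
        theta = [t - alpha * lam * t for t, lam in zip(theta, curvatures)]
    return theta


stable = gd(alpha=0.019)    # just below 2 / 100: the steep coordinate converges
unstable = gd(alpha=0.021)  # just above 2 / 100: the steep coordinate diverges
print("alpha=0.019:", stable)
print("alpha=0.021:", unstable)
```

With α = 0.019 the steep coordinate has essentially vanished after 100 steps while the flat one has only shrunk to about 0.15; nudging α to 0.021 makes the steep coordinate blow up.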
checkpoint 2
(Stochastic) gradient descent takes the same step size along all gradient directions, even though parameters may have different variances and be correlated with each other. Using curvature information about our loss function would let us adapt the step to each direction.
using the Hessian to improve optimizers
We know that the Hessian gives us local curvature information about our loss surface around a point θₜ, but how do we integrate this into our gradient descent algorithm? Like many things in optimization, the answer comes from approximating a smooth function locally around a point by its Taylor series expansion. The second-order Taylor series of our loss function L around the point θₜ is

$$L(\theta) \approx L(\theta_t) + \nabla L(\theta_t)^{\top}(\theta - \theta_t) + \tfrac{1}{2}(\theta - \theta_t)^{\top} H(\theta_t)\,(\theta - \theta_t).$$
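To find the minimum of this approximation, we compute its derivative with respect to θ, set it to zero, and solve for θ to define our iterator:

$$\nabla L(\theta_t) + H(\theta_t)(\theta - \theta_t) = 0 \quad\Longrightarrow\quad \theta_{t+1} = \theta_t - H(\theta_t)^{-1}\nabla L(\theta_t).$$

Computing this step means solving the linear system H(θₜ) δ = ∇L(θₜ), which has exactly the form Ax=b from the first section.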
This second-order algorithm is called Newton’s method, and it converges in one step on quadratic functions with a positive definite Hessian, since the second-order Taylor expansion is then exact. Deep and massively nonlinear neural networks are not quadratic, but around our current parameters the loss is approximately quadratic, so we take a shorter step in the Newton direction. This leads to damped Newton’s method, i.e.

$$\theta_{t+1} = \theta_t - \alpha\, H(\theta_t)^{-1}\nabla L(\theta_t), \qquad 0 < \alpha \le 1.$$
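Here is a sketch of both claims on a toy quadratic L(θ) = ½θᵀHθ − gᵀθ with a made-up, strongly anisotropic H: one undamped Newton step lands on the minimizer from an arbitrary start, and with damping α = 0.5 the distance to the minimizer halves every step, regardless of how steep or flat each direction is. The helper names are hypothetical, and the step is computed by solving Hδ = ∇L rather than forming H⁻¹.

```python
import math


def solve2(A, rhs):
    """Solve the 2x2 system A x = rhs with Cramer's rule."""
    (a, b), (c, d) = A
    det = a * d - b * c
    return [(rhs[0] * d - b * rhs[1]) / det, (a * rhs[1] - c * rhs[0]) / det]


# Toy quadratic L(theta) = 0.5 * theta^T H theta - g^T theta (made-up H and g).
H = [[100.0, 9.0], [9.0, 1.0]]  # symmetric positive definite, k(H) in the hundreds
g = [1.0, 2.0]


def grad(theta):
    """Gradient of the toy quadratic: H theta - g."""
    return [H[i][0] * theta[0] + H[i][1] * theta[1] - g[i] for i in range(2)]


def newton_step(theta, alpha=1.0):
    """Damped Newton update theta - alpha * H^{-1} grad L(theta).

    H^{-1} is never formed explicitly; we solve H delta = grad L instead.
    """
    delta = solve2(H, grad(theta))
    return [theta[i] - alpha * delta[i] for i in range(2)]


theta_star = solve2(H, g)       # exact minimizer, where grad L = 0
theta0 = [5.0, -3.0]
one_step = newton_step(theta0)  # a single undamped Newton step

theta, errors = theta0, []
for _ in range(3):
    theta = newton_step(theta, alpha=0.5)
    errors.append(math.dist(theta, theta_star))

print("one Newton step lands at", one_step, "vs minimizer", theta_star)
print("damped (alpha=0.5) distances to the minimizer:", errors)
```

On a quadratic the damped update gives θₜ₊₁ − θ* = (1 − α)(θₜ − θ*), so the contraction is the same in every direction, unlike the plain gradient descent behavior on anisotropic curvature.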
The Hessian is symmetric by construction (mixed partial derivatives commute for twice continuously differentiable functions), so its condition number k(H) is the ratio between the steepest and flattest curvatures, i.e. how anisotropic our approximation is around a point. In particular, each eigenvalue is the curvature along its eigenvector, so using the eigendecomposition H = QΛQᵀ we can measure the curvature along a particular gradient direction g:

$$\frac{g^{\top} H g}{g^{\top} g} = \sum_i \lambda_i \frac{(q_i^{\top} g)^2}{\|g\|^2}.$$
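A minimal 2×2 sketch of this decomposition, with a made-up Hessian and gradient and hypothetical helper names: it checks that the curvature along g (the Rayleigh quotient gᵀHg / gᵀg) equals the eigenvalue-weighted sum, and builds the Newton direction H⁻¹g in the eigenbasis by dividing each component of g by its eigenvalue.

```python
import math


def sym_eig2(S):
    """Eigenpairs (lambda, unit eigenvector) of a symmetric 2x2 matrix, largest first.

    Assumes a nonzero off-diagonal entry, which holds for the example below.
    """
    a, b, c = S[0][0], S[0][1], S[1][1]
    mid, rad = (a + c) / 2, math.hypot((a - c) / 2, b)
    pairs = []
    for lam in (mid + rad, mid - rad):
        v = (b, lam - a)  # solves (S - lam I) v = 0 when b != 0
        n = math.hypot(*v)
        pairs.append((lam, (v[0] / n, v[1] / n)))
    return pairs


def dot(u, v):
    return u[0] * v[0] + u[1] * v[1]


H = [[3.0, 1.0], [1.0, 2.0]]  # a made-up symmetric Hessian
g = (1.0, 1.0)                # a made-up gradient
pairs = sym_eig2(H)

# Curvature along g: the Rayleigh quotient g^T H g / g^T g ...
Hg = (dot(H[0], g), dot(H[1], g))
rayleigh = dot(g, Hg) / dot(g, g)
# ... equals the eigenvalues weighted by how much of g lies along each eigenvector.
weighted = sum(lam * dot(q, g) ** 2 for lam, q in pairs) / dot(g, g)

# Newton direction H^{-1} g in the eigenbasis: each component of g is divided
# by its eigenvalue, shrinking steep directions relative to flat ones.
newton_dir = [0.0, 0.0]
for lam, q in pairs:
    coeff = dot(q, g) / lam
    newton_dir = [newton_dir[0] + coeff * q[0], newton_dir[1] + coeff * q[1]]

k_H = pairs[0][0] / pairs[1][0]
print(f"curvature along g = {rayleigh:.3f}, k(H) = {k_H:.3f}, H^-1 g = {newton_dir}")
```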
We also know that the larger the condition number, the more sensitive a linear system is to small errors in b. The Newton step solves a linear system with the Hessian as its matrix and the gradient as its right-hand side, so noise in the gradient (from minibatch sampling and from floating point error) can be amplified by up to k(H). That is to say: when we need the inverse Hessian the most, it is also the hardest to apply accurately because of floating point and truncation error.
checkpoint 3
The Hessian (second derivative) gives us curvature information along every direction in parameter space. By transforming the gradient at a point θₜ with the inverse Hessian, we take larger steps along flat directions and smaller steps along steep directions.
most optimizers can be explained as preconditioned SGD
Even if we wanted to use the Hessian directly, for large-scale nonlinear models it’s computationally infeasible: a model with d parameters has a d × d Hessian, and inverting it costs O(d³). But thinking about the Hessian as a preconditioner for gradient descent, one that rescales each gradient direction inversely proportional to the local curvature, captures the fundamental goal behind most modern neural network optimizers: equalizing the variance across gradient directions. The preconditioner we choose defines many common algorithms like AdamW, Shampoo, Muon, SPlus and more. These are often called “whitening” methods, since “whitening” the gradient transforms the gradient space so that the effective covariance becomes the identity, i.e. there is unit variance in all directions.
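What whitening means can be shown on synthetic, correlated 2-D "gradient" samples. This sketch applies full-matrix ZCA-style whitening, multiplying each sample by C^{-1/2}, where C is the empirical covariance, so that the whitened samples have identity covariance. It is an illustration only, not how AdamW, Shampoo or Muon are implemented; loosely, those methods can be viewed as cheaper, structured approximations of this kind of transform (for example, AdamW only uses per-coordinate scales). The data and helper names are assumptions.

```python
import math
import random

random.seed(0)


def sym_eig2(S):
    """Eigenpairs (lambda, unit eigenvector) of a symmetric 2x2 matrix."""
    a, b, c = S[0][0], S[0][1], S[1][1]
    if b == 0.0:
        return [(a, (1.0, 0.0)), (c, (0.0, 1.0))]
    mid, rad = (a + c) / 2, math.hypot((a - c) / 2, b)
    pairs = []
    for lam in (mid + rad, mid - rad):
        v = (b, lam - a)  # solves (S - lam I) v = 0 when b != 0
        n = math.hypot(*v)
        pairs.append((lam, (v[0] / n, v[1] / n)))
    return pairs


def covariance(vs):
    """Empirical 2x2 covariance matrix of a list of 2-vectors."""
    n = len(vs)
    m0 = sum(v[0] for v in vs) / n
    m1 = sum(v[1] for v in vs) / n
    c00 = sum((v[0] - m0) ** 2 for v in vs) / n
    c01 = sum((v[0] - m0) * (v[1] - m1) for v in vs) / n
    c11 = sum((v[1] - m1) ** 2 for v in vs) / n
    return [[c00, c01], [c01, c11]]


# Synthetic "gradient" samples with very different scales and strong correlation.
grads = []
for _ in range(1000):
    z1, z2 = random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)
    grads.append((3.0 * z1, 0.5 * z1 + 0.2 * z2))

C = covariance(grads)

# Whitening matrix W = C^{-1/2} = Q diag(lambda^{-1/2}) Q^T.
W = [[0.0, 0.0], [0.0, 0.0]]
for lam, q in sym_eig2(C):
    s = 1.0 / math.sqrt(lam)
    for i in range(2):
        for j in range(2):
            W[i][j] += s * q[i] * q[j]

whitened = [(W[0][0] * g0 + W[0][1] * g1, W[1][0] * g0 + W[1][1] * g1)
            for g0, g1 in grads]
print("covariance before:", C)
print("covariance after: ", covariance(whitened))  # the identity, up to rounding
```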
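Consider a general preconditioned gradient descent with a (typically positive definite) preconditioner Pₜ:

$$\theta_{t+1} = \theta_t - \alpha\, P_t\, \nabla L(\theta_t).$$

Choosing Pₜ = I recovers gradient descent, and Pₜ = H(θₜ)⁻¹ recovers Newton’s method.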
Then we have that
I find it quite useful to think of optimizer design through the lens of preconditioning. AdamW, for example, uses a diagonal preconditioner that scales each parameter independently, dividing its step by a running estimate of the root mean square of that parameter’s gradient, so the effective learning rate is inversely proportional to the gradient’s scale along that direction. I’ll write more about optimizers in another blog post, but hopefully this helps you build a consistent frame for thinking about optimizer design.
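As a sketch of that diagonal preconditioner, here is one step of the standard Adam update written in pure Python, with a hypothetical function name and AdamW’s decoupled weight decay (a separate −lr·λ·θ term) omitted. The update is −lr · m̂ / (√v̂ + ε), i.e. the momentum m̂ multiplied by P = diag(1/(√v̂ + ε)). On the first step m̂ = g and v̂ = g², so every coordinate moves by about lr regardless of how large its gradient is.

```python
import math


def adam_update(grads, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step without weight decay; returns (update, m, v).

    The update is -lr * m_hat / (sqrt(v_hat) + eps): a diagonal preconditioner
    diag(1 / (sqrt(v_hat) + eps)) applied to the bias-corrected momentum m_hat.
    """
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grads)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grads)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]  # bias correction
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    update = [-lr * mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    return update, m, v


# Two parameters whose gradients differ by six orders of magnitude.
grads = [1e3, -1e-3]
update, m, v = adam_update(grads, [0.0, 0.0], [0.0, 0.0], t=1)
# Rescaling every gradient by 10 leaves the (first) update essentially unchanged.
scaled_update, _, _ = adam_update([10 * g for g in grads], [0.0, 0.0], [0.0, 0.0], t=1)
print(update, scaled_update)
```

Compared with the full-matrix preconditioners above, the diagonal form costs O(d) memory and compute, at the price of ignoring correlations between parameters.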