Linear Algebra

The Gradient Problem

Suppose you have a function $f(\mathbf{t})$ built from hundreds of simple operations chained together, like a neural network computing its loss from its parameters. To train with gradient descent, you need every partial derivative $\frac{\partial f}{\partial t_i}$.

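The obvious approach is to differentiate symbolically through the chain. But when an intermediate result feeds into several later operations, as happens throughout a neural network, applying the chain rule recursively re-derives the same intermediate expressions again and again, and the repeated work can grow exponentially with the depth of the circuit. For a circuit with 1000 nodes, this blow-up can be catastrophic.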
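To see the redundancy concretely, here is a minimal sketch on a hypothetical toy circuit $x_0 = t$, $x_{k+1} = x_k \cdot x_k$, in which every node feeds the next one twice. The function name `naive_derivative` is illustrative. It applies the product rule recursively and re-derives each shared subexpression from scratch, so the work is measured in recursive calls rather than in the size of a symbolic expression.

```python
def naive_derivative(depth, t):
    """Differentiate x_depth with respect to t for the toy circuit
    x_0 = t, x_{k+1} = x_k * x_k, re-expanding shared subexpressions
    instead of reusing them. Returns (derivative, number_of_calls)."""
    calls = [0]  # mutable counter shared by the nested helpers

    def value(k):
        # Recompute x_k from scratch: both factors are evaluated separately.
        calls[0] += 1
        if k == 0:
            return t
        return value(k - 1) * value(k - 1)

    def deriv(k):
        # Product rule d(x*x) = dx*x + x*dx, with nothing cached.
        calls[0] += 1
        if k == 0:
            return 1.0
        return deriv(k - 1) * value(k - 1) + value(k - 1) * deriv(k - 1)

    return deriv(depth), calls[0]


for depth in (4, 8, 12):
    d, calls = naive_derivative(depth, 1.0)
    print(f"depth={depth:2d}  circuit nodes={depth + 1:2d}  calls={calls}")
```

The circuit has only $k + 1$ nodes at depth $k$, yet the call count works out to $k \cdot 2^{k+1} + 1$, because every shared node is re-derived once per path that reaches it.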

Backpropagation is smarter. It computes all partials at a fixed input $\mathbf{t}_0$ using exactly two passes over the circuit: one forward, one backward. Intermediate values computed in the forward pass are reused in the backward pass, eliminating the redundancy. The total cost is a small constant multiple of evaluating $f$ once; informally, it is roughly the same as evaluating $f$ twice.
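The two passes can be sketched in a few lines of Python. This is a minimal sketch under a hypothetical circuit encoding (an assumption, not something the text specifies): inputs occupy node indices `0..n-1`, and each later node is a tuple `(op, a, b)` naming earlier node indices, with only `add`, `mul`, and `sin` supported. The forward pass stores every node value; the backward pass walks the nodes in reverse and accumulates each node's adjoint, the partial derivative $\partial f / \partial(\text{node})$, reusing those stored values.

```python
import math

def backprop(circuit, inputs):
    """Return (f(inputs), [df/dt_i for each input]) for a circuit listed in
    topological order. Node indices 0..n-1 are the inputs; node n + k is
    circuit[k] = (op, a, b), where a and b index earlier nodes
    (b is unused for 'sin'). The last node is the output f."""
    n = len(inputs)
    vals = [float(x) for x in inputs]

    # Forward pass: evaluate each node once and keep its value for reuse.
    for op, a, b in circuit:
        if op == "add":
            vals.append(vals[a] + vals[b])
        elif op == "mul":
            vals.append(vals[a] * vals[b])
        elif op == "sin":
            vals.append(math.sin(vals[a]))
        else:
            raise ValueError(f"unknown op: {op!r}")

    # Backward pass: adj[i] holds df/d(node i); seed the output with 1.
    adj = [0.0] * len(vals)
    adj[-1] = 1.0
    for j in range(len(vals) - 1, n - 1, -1):
        op, a, b = circuit[j - n]
        g = adj[j]
        if op == "add":
            adj[a] += g
            adj[b] += g
        elif op == "mul":
            # Local partials of x*y use the values saved in the forward pass.
            adj[a] += g * vals[b]
            adj[b] += g * vals[a]
        elif op == "sin":
            adj[a] += g * math.cos(vals[a])
    return vals[-1], adj[:n]


# Hypothetical example: f(t1, t2) = t1 * t2 + sin(t1); nodes 0 and 1 are t1, t2.
circuit = [("mul", 0, 1), ("sin", 0, None), ("add", 2, 3)]
f, grad = backprop(circuit, [2.0, 3.0])
print(f, grad)
```

Here the exact gradient is $(t_2 + \cos t_1,\ t_1)$, i.e. $(3 + \cos 2,\ 2)$ at $(2, 3)$. Each node is visited once going forward and once going backward, and a single backward pass fills in the partials for every input at the same time.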

Formal View

Definition 15.1 — The Gradient Problem
Given a function $f(\mathbf{t})$ described by a feedforward (acyclic) circuit and a fixed point $\mathbf{t}_0$, compute the partial derivative $\frac{\partial f}{\partial t_i}(\mathbf{t}_0)$ for each input $t_i$.
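As a small worked instance (the function is a hypothetical example, not taken from the text), take $f(t_1, t_2) = t_1 t_2 + \sin t_1$ at $\mathbf{t}_0 = (2, 3)$. The gradient problem asks for two numbers:

```latex
\frac{\partial f}{\partial t_1}(\mathbf{t}_0) = \left(t_2 + \cos t_1\right)\Big|_{\mathbf{t}_0} = 3 + \cos 2 \approx 2.584,
\qquad
\frac{\partial f}{\partial t_2}(\mathbf{t}_0) = t_1\Big|_{\mathbf{t}_0} = 2.
```

The answer is a vector of numbers evaluated at $\mathbf{t}_0$, not a pair of formulas.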
Remark 15.1 — Neural Networks
In machine learning, $\mathbf{t}$ holds the trainable parameters, $f(\mathbf{t})$ is the loss on the training data, and the partials $\frac{\partial f}{\partial t_i}$ drive gradient descent updates. With $n$ inputs, backpropagation computes all $n$ partials in $O(|\text{circuit}|)$ time instead of the $O(n \cdot |\text{circuit}|)$ needed to handle each input separately.
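The $O(n \cdot |\text{circuit}|)$ alternative is easy to see with forward-mode differentiation, which pushes one derivative direction through the circuit per evaluation. The sketch below is illustrative only: `Dual`, `dsin`, and `gradient_forward_mode` are hypothetical names, and the dual-number class supports just `+`, `*`, and `sin`.

```python
import math

class Dual:
    """A value plus its derivative along one chosen input direction (forward mode)."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule carried alongside the value.
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __radd__ = __add__
    __rmul__ = __mul__

def dsin(x):
    """sin for dual numbers: d sin(x) = cos(x) * dx."""
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

def gradient_forward_mode(f, t0):
    """One full evaluation of f per input: n passes for n partials."""
    grads, passes = [], 0
    for i in range(len(t0)):
        # Seed direction e_i: derivative 1 for input i, 0 for the others.
        args = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(t0)]
        grads.append(f(*args).dot)
        passes += 1
    return grads, passes

def f(t1, t2):
    return t1 * t2 + dsin(t1)

grads, passes = gradient_forward_mode(f, [2.0, 3.0])
print(grads, passes)
```

Every extra input costs another full evaluation of the circuit, whereas a single backward pass delivers all $n$ partials at once.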

Why This Matters

Virtually every modern neural network, from GPT to image classifiers to AlphaFold, is trained with gradients computed by backpropagation.

  • Training deep neural networks with millions of parameters efficiently
  • PyTorch autograd and JAX implement reverse-mode automatic differentiation, the general form of this algorithm
  • Physics simulators and differentiable renderers that need gradients
  • Gradient-based hyperparameter optimization and meta-learning

Quiz

Question 1

Why is naive symbolic differentiation inefficient for large circuits?

Question 2

Backpropagation computes all partials $\frac{\partial f}{\partial t_i}(\mathbf{t}_0)$ using:

Question 3

Backpropagation computes exact partial derivatives (not approximations) at a specific point $\mathbf{t}_0$.

Question 4

The circuit used for backpropagation must be acyclic (no loops).

Question 5

Backpropagation computes the gradient of the circuit output with respect to:

Common Mistakes

  • Confusing backpropagation with gradient descent — backprop only *computes* the gradients; gradient descent uses them to update parameters.
  • Thinking backprop produces symbolic derivatives: it computes numerical values at a specific $\mathbf{t}_0$, not symbolic formulas.
  • Assuming each input needs its own backward pass — that is exactly the naive approach backprop avoids.
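To keep the first mistake straight, here is a minimal sketch in which computing the gradient and updating the parameters live in separate functions. The names `grad_f` and `gradient_descent` are hypothetical, and `grad_f` merely stands in for backpropagation by returning hand-computed partials of the toy loss $f(\mathbf{t}) = (t_1 - 1)^2 + (t_2 + 2)^2$. The update rule only ever sees numbers, and it needs one full gradient per step, not one pass per input.

```python
def grad_f(t):
    """Stand-in for backprop: numerical partials of
    f(t) = (t1 - 1)^2 + (t2 + 2)^2 evaluated at the point t."""
    t1, t2 = t
    return [2.0 * (t1 - 1.0), 2.0 * (t2 + 2.0)]

def gradient_descent(grad, t, lr=0.1, steps=100):
    """Gradient descent: repeatedly step against the gradient returned by `grad`."""
    t = list(t)
    for _ in range(steps):
        g = grad(t)  # all partials at the current point (one backward pass in practice)
        t = [ti - lr * gi for ti, gi in zip(t, g)]
    return t

print(gradient_descent(grad_f, [0.0, 0.0]))
```

Swapping `grad_f` for a real backprop routine would leave `gradient_descent` unchanged, which is exactly the division of labor between the two.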