
The Function Node Rule

Here is the core of backpropagation: how does a function node pass gradient information backward?

Node $x$ computes $x(\mathbf{s})$ from inputs $s_1, \ldots, s_k$. During the backward pass, it receives a number along its output wire, the "upstream gradient" $\frac{\partial f(x, \mathbf{t})}{\partial x}(x_0, \mathbf{t}_0)$. Its job: produce a gradient for each input wire $s_i$.

The answer is a product: upstream gradient × local gradient. The local gradient for input $s_i$ is $\frac{\partial x(\mathbf{s})}{\partial s_i}(\mathbf{s}_0)$, computable from the stored $\mathbf{s}_0$. The node multiplies the two and sends the result backward along wire $s_i$.

This is just the chain rule, applied one node at a time. Each node does a tiny local computation — no node needs to know the structure of the rest of the circuit.
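
As a rough illustration of the rule, here is a minimal Python sketch of a single function node. The class name `FunctionNode`, its `forward`/`backward` methods, and the example function $x = s_1 \sin(s_2)$ are all hypothetical choices made for this sketch, not part of any library. The node stores $\mathbf{s}_0$ on the forward pass and, on the backward pass, multiplies the upstream gradient by each hand-written local partial.

```python
import math


class FunctionNode:
    """One node in a computation circuit (illustrative sketch, not a library API)."""

    def __init__(self, fn, partials):
        self.fn = fn              # forward function x(s)
        self.partials = partials  # partials[i](*s) returns dx/ds_i at s
        self.s0 = None            # inputs stored during the forward pass

    def forward(self, *s):
        self.s0 = s               # remember s_0 for the backward pass
        return self.fn(*s)

    def backward(self, upstream):
        # One gradient per input wire: upstream gradient times local gradient.
        return [upstream * p(*self.s0) for p in self.partials]


# Node x = s1 * sin(s2); the partial derivatives are worked out by hand.
node = FunctionNode(
    fn=lambda s1, s2: s1 * math.sin(s2),
    partials=[
        lambda s1, s2: math.sin(s2),       # dx/ds1
        lambda s1, s2: s1 * math.cos(s2),  # dx/ds2
    ],
)

x0 = node.forward(2.0, 0.0)   # x0 = 2 * sin(0) = 0
grads = node.backward(10.0)   # an upstream gradient of 10 arrives on the output wire
print(grads)                  # [0.0, 20.0]
```

Note that `backward` only touches the node's own stored inputs and the single incoming number; nothing about the rest of the circuit is needed.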

Formal View

Lemma 15.1 — Function Node Lemma
Let node $x$ have inputs $\mathbf{s}$ and compute $x(\mathbf{s})$. Then:
$$\frac{\partial f(s_i, \mathbf{t})}{\partial s_i}([s_i]_0, \mathbf{t}_0) = \underbrace{\frac{\partial f(x, \mathbf{t})}{\partial x}(x_0, \mathbf{t}_0)}_{\text{upstream gradient}} \cdot \underbrace{\frac{\partial x(\mathbf{s})}{\partial s_i}(\mathbf{s}_0)}_{\text{local gradient}}$$

This is one application of the multivariate chain rule. The product has exactly two factors: the upstream gradient (arriving from the right) and the local gradient (computed from the stored $\mathbf{s}_0$).
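
The lemma can be sanity-checked numerically. The sketch below assumes a made-up circuit, $x(\mathbf{s}) = s_1 s_2 + s_1$ feeding into $f(x, t) = x^2 + 3t$, chosen only for illustration. It computes each input gradient as upstream × local and compares the result against a central finite difference of the composite function.

```python
def x_of_s(s1, s2):
    return s1 * s2 + s1          # the node's function x(s)


def f_of_x(x, t):
    return x ** 2 + 3.0 * t      # the rest of the circuit, as a function of x and t


s1_0, s2_0, t_0 = 1.5, -2.0, 0.5
x_0 = x_of_s(s1_0, s2_0)         # forward pass: x_0 = -1.5

upstream = 2.0 * x_0             # df/dx evaluated at (x_0, t_0)
local_s1 = s2_0 + 1.0            # dx/ds1 evaluated at s_0
local_s2 = s1_0                  # dx/ds2 evaluated at s_0
grad_s1 = upstream * local_s1    # analytic value: 3.0
grad_s2 = upstream * local_s2    # analytic value: -4.5


def composite(s1, s2):
    # f viewed as a function of the node's inputs, with t held at t_0
    return f_of_x(x_of_s(s1, s2), t_0)


h = 1e-6
fd_s1 = (composite(s1_0 + h, s2_0) - composite(s1_0 - h, s2_0)) / (2 * h)
fd_s2 = (composite(s1_0, s2_0 + h) - composite(s1_0, s2_0 - h)) / (2 * h)

print(grad_s1, fd_s1)            # the two numbers should agree closely
print(grad_s2, fd_s2)
```

The finite-difference estimate needs the whole composite function, while the lemma's product needs only one number from downstream plus the node's own partials.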

[Interactive visualization: Backpropagation Circuit]

Why This Matters

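Every differentiable PyTorch operation (matmul, ReLU, softmax, etc.) applies this rule, in its tensor-valued form, during the backward pass: given the upstream gradient for its output, it returns upstream × local gradient for each input. Calling .backward() on the final result triggers these per-operation computations.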

  • ReLU backward: local gradient is 1 if input > 0, else 0
  • Multiply node: gradient to each input = other input × upstream gradient
  • Sigmoid backward: local gradient is $\sigma(1 - \sigma)$, evaluated at the stored input
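  • Matrix multiply backward: for $Y = AB$ with upstream gradient $G$, the gradient for $A$ is $G B^\top$ and the gradient for $B$ is $A^\top G$ (upstream times the transpose of the other factor, with the order of multiplication preserved)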
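
The four rules above can be sketched in plain Python. The function names below (`relu_backward`, `multiply_backward`, `sigmoid_backward`, `matmul_backward`) are hypothetical helpers written for this illustration and are not PyTorch's actual internals. Matrices are represented as nested lists, and for the matrix product $Y = AB$ with upstream gradient $G$ the sketch assumes the standard results $G B^\top$ for $A$ and $A^\top G$ for $B$.

```python
import math


def relu_backward(upstream, s0):
    # Local gradient is 1 where the stored input is positive, else 0.
    return upstream * (1.0 if s0 > 0 else 0.0)


def multiply_backward(upstream, a0, b0):
    # For x = a * b: each input receives the other input times upstream.
    return upstream * b0, upstream * a0


def sigmoid_backward(upstream, s0):
    sig = 1.0 / (1.0 + math.exp(-s0))   # recompute sigma at the stored input
    return upstream * sig * (1.0 - sig)


def matmul(A, B):
    # Plain nested-list matrix product.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul_backward(G, A0, B0):
    # For Y = A B with upstream G: gradient for A is G B^T, for B is A^T G.
    return matmul(G, transpose(B0)), matmul(transpose(A0), G)


print(relu_backward(2.0, -1.0))            # 0.0
print(multiply_backward(2.0, 3.0, 5.0))    # (10.0, 6.0)
print(sigmoid_backward(4.0, 0.0))          # 1.0
dA, dB = matmul_backward([[1.0]], [[1, 2]], [[3], [4]])
print(dA, dB)                              # [[3.0, 4.0]] [[1.0], [2.0]]
```

In every case the function receives only the upstream gradient and the values stored on the forward pass, which mirrors the two factors in the lemma.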

Quiz

Question 1

Node $x = s_1 \cdot s_2$ receives upstream gradient $\delta = 5$. Stored: $s_{1,0} = 3$, $s_{2,0} = 4$. What gradient goes on wire $s_1$?

Question 2

For node $x = \sin(s)$, stored $s_0 = \pi/2$, upstream gradient $\delta = 2$. What gradient goes on wire $s$?

Question 3

What two values does a function node multiply together in the backward pass?

Question 4

True or false: a function node with two inputs sends the same gradient to both input wires.

Question 5

Why does the function node need the stored $\mathbf{s}_0$?

Common Mistakes

  • Multiplying by the forward value $x_0$ instead of by the local derivative $\frac{\partial x}{\partial s_i}$.
  • Sending the same gradient to all input wires — each gets a different local partial.
  • Confusing the upstream gradient (arriving from the right) with the gradient being sent leftward (what the node produces).