Linear Algebra

The Splitter Node Rule

Splitter nodes handle the case where one value feeds multiple downstream computations. A splitter takes input $s$ and produces copies $x_1, x_2$. In the backward pass, it receives two gradient numbers, one from each output wire. Its job: compute the gradient to place on input wire $s$.

There are two paths from $s$ to the output $f$ (through $x_1$ and through $x_2$). When a variable affects the output through multiple routes, the total derivative is the sum of each route's contribution. This is the multivariate chain rule.

Since the splitter just copies ($x_1 = x_2 = s$, so the local partial derivative is 1 for each output), the gradient on $s$ is simply the sum of the two incoming gradients. More generally, if $s$ fans out to $k$ copies, add all $k$ incoming gradients.
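The rule can be checked numerically. The sketch below is illustrative only: the downstream function `f`, the helper `splitter_backward`, and the test point `s0` are assumptions chosen for the example, not part of the lesson. It sums the two incoming gradients and compares the result with a central finite difference.

```python
def splitter_backward(upstream_grads):
    """Gradient on the splitter's input: the sum of all incoming gradients."""
    return sum(upstream_grads)


def f(x1, x2):
    # Hypothetical downstream computation that uses both copies.
    return x1 * x1 + 3.0 * x2


def g(s):
    # The splitter copies s onto both output wires.
    return f(s, s)


s0 = 2.0

# Gradients arriving on each output wire, evaluated at x1 = x2 = s0.
df_dx1 = 2.0 * s0  # derivative of x1^2 with respect to x1
df_dx2 = 3.0       # derivative of 3*x2 with respect to x2

grad_s = splitter_backward([df_dx1, df_dx2])

# Independent check: central finite difference of g at s0.
h = 1e-6
numeric = (g(s0 + h) - g(s0 - h)) / (2 * h)

print(grad_s, numeric)
```

Here `grad_s` is 7.0, which matches differentiating $g(s) = s^2 + 3s$ directly at $s = 2$.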

This sum rule explains why a parameter that appears many times in a model receives a gradient accumulated over all of its uses: each use contributes one term to the sum.

Formal View

Lemma 15.2 — Splitter Node Lemma
Let a splitter take input $s$ and produce copies $x_1, x_2$. Then:
$$\frac{\partial f(s, \mathbf{t})}{\partial s}(s_0, \mathbf{t}_0) = \frac{\partial f(x_1, \mathbf{t})}{\partial x_1}([x_1]_0, \mathbf{t}_0) + \frac{\partial f(x_2, \mathbf{t})}{\partial x_2}([x_2]_0, \mathbf{t}_0)$$

There are two paths from $s$ to the output, so the two gradient contributions are summed. The splitter's local Jacobian is $[1, 1]^t$, so the chain rule gives this sum.
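The chain-rule step behind the lemma can be written out explicitly. This is a sketch under one added assumption: $\tilde{f}(x_1, x_2, \mathbf{t})$ is introduced here (it is not a symbol from the lemma) to name the downstream function of the two copies, so that $f(s, \mathbf{t}) = \tilde{f}(s, s, \mathbf{t})$.

```latex
% Assumption: \tilde{f}(x_1, x_2, \mathbf{t}) is the downstream function of the two copies,
% so that f(s, \mathbf{t}) = \tilde{f}(s, s, \mathbf{t}).
\begin{aligned}
\frac{\partial f}{\partial s}
  &= \frac{\partial \tilde{f}}{\partial x_1}\,\frac{\partial x_1}{\partial s}
   + \frac{\partial \tilde{f}}{\partial x_2}\,\frac{\partial x_2}{\partial s} \\
  &= \frac{\partial \tilde{f}}{\partial x_1}\cdot 1
   + \frac{\partial \tilde{f}}{\partial x_2}\cdot 1 \\
  &= \begin{bmatrix}
       \frac{\partial \tilde{f}}{\partial x_1} & \frac{\partial \tilde{f}}{\partial x_2}
     \end{bmatrix}
     \begin{bmatrix} 1 \\ 1 \end{bmatrix}
\end{aligned}
```

The last line is the row of incoming gradients multiplied by the splitter's local Jacobian, which is why the backward pass reduces to a plain sum.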

Why This Matters

The sum rule explains why shared parameters (weight tying, attention heads) accumulate gradients from all their uses.

  • Shared weights in RNNs: same matrix used at each timestep, gradients sum over all steps
  • Residual connections (skip connections) in ResNets: gradient flows through two paths and adds
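  • Attention mechanisms: the same key/query/value projection matrices are applied at every position, so their gradients sum over positions
  • Gradient accumulation for large effective batch sizes exploits this additive structure: gradients from successive micro-batches are summed before the update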
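The shared-weight case from the list above can be sketched with a scalar recurrence. Everything here is an illustrative assumption: the function `unrolled`, the helper `partial`, and the values of `w` and `x`. The idea is to treat each timestep's use of the weight as a separate copy, differentiate with respect to each copy, and sum.

```python
def unrolled(w1, w2, w3, x):
    # Three timesteps of a scalar recurrence h_t = w_t * h_{t-1},
    # with one weight copy per step (hypothetical toy model).
    h = x
    for w in (w1, w2, w3):
        h = w * h
    return h


def partial(fn, args, i, h=1e-6):
    # Central finite difference with respect to argument i.
    up, dn = list(args), list(args)
    up[i] += h
    dn[i] -= h
    return (fn(*up) - fn(*dn)) / (2 * h)


w, x = 2.0, 1.5

# Gradient flowing back to each copy of the shared weight.
per_copy = [partial(unrolled, (w, w, w, x), i) for i in range(3)]

# Splitter rule: the shared weight's gradient is the sum over its uses.
grad_w = sum(per_copy)

# Direct derivative of w^3 * x with respect to w, for comparison.
analytic = 3 * w**2 * x

print(per_copy, grad_w, analytic)
```

Each copy contributes roughly $w^2 x$, and the three contributions add up to the derivative of $w^3 x$ with respect to the tied weight.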

Quiz

Question 1

A splitter sends $t_1$ to nodes $A$ and $B$. Node $A$ sends back gradient 3; node $B$ sends back 7. What is the gradient on $t_1$?

Question 2

Why does the splitter add (not multiply) incoming gradients?

Question 3

A splitter with $k$ outputs must collect all $k$ upstream gradients before it can send a gradient backward.

Question 4

A residual block computes $y = F(x) + x$. Using the splitter rule, $\frac{\partial y}{\partial x}$ is:

Question 5

If $t_1$ feeds a single downstream node (no splitter), the function node rule applies directly to $t_1$.

Common Mistakes

  • Swapping the rules: adding at function nodes (should multiply) and multiplying at splitters (should add).
  • Forgetting the splitter must wait for ALL downstream gradients before computing its sum.
  • Taking the average of incoming gradients rather than the sum.
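One way to guard against the last two mistakes is to make the node itself enforce them. The class below is a minimal sketch, not a reference implementation: the name `Splitter`, its `receive` method, and the example gradient values are all hypothetical. It buffers incoming gradients and only emits a result once all $k$ have arrived, and that result is a sum rather than an average.

```python
class Splitter:
    """Hypothetical fan-out node that buffers upstream gradients until all k arrive."""

    def __init__(self, k):
        self.k = k
        self.received = []

    def receive(self, grad):
        """Record one upstream gradient.

        Returns None while gradients are still missing, and the gradient
        for the input wire once all k have been received.
        """
        if len(self.received) >= self.k:
            raise RuntimeError("all upstream gradients were already received")
        self.received.append(grad)
        if len(self.received) < self.k:
            return None  # not ready: some output wires have not reported yet
        return sum(self.received)  # sum, not average


node = Splitter(3)
first = node.receive(2.0)    # None: still waiting on two wires
second = node.receive(-0.5)  # None: still waiting on one wire
total = node.receive(4.0)    # all three received, so the sum is emitted

print(first, second, total)
```

Returning `None` until the buffer is full makes a premature backward step easy to detect in the caller.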