Entry 23 of 24
ML Fundamentals Series

Backpropagation Through Time Is Just Backprop That Remembers Every Timestep

An RNN's hidden state at time $t$ depends on the hidden state at $t-1$, which depends on $t-2$, all the way back to the start of the sequence. That chain is what gives the network memory, and it's also exactly what makes training one harder than training a feedforward network. Backpropagation Through Time (BPTT) is ordinary backpropagation applied to the network unfolded across the sequence, one copy per timestep with the same weights shared by every copy, instead of across a stack of distinct layers.

The forward equations are simple to state. At each timestep, the hidden state combines the current input with the previous hidden state, then applies an activation function:

$$S_t = g_1(W_x X_t + W_s S_{t-1}) \qquad Y_t = g_2(W_y S_t)$$
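To make the two equations concrete, here is a minimal sketch of the unrolled forward pass for a scalar RNN (one-dimensional input, state, and output). The article leaves $g_1$ and $g_2$ unspecified; this sketch assumes $g_1 = \tanh$, $g_2$ = identity, and an initial state $S_0 = 0$. The function name `rnn_forward` and the numbers are illustrative.

```python
import math

def rnn_forward(xs, w_x, w_s, w_y, s0=0.0):
    """Unrolled forward pass of a scalar RNN.

    Assumes g1 = tanh and g2 = identity (the article leaves both open).
    Returns the hidden states S_0..S_T and the outputs Y_1..Y_T.
    """
    states = [s0]   # states[t] holds S_t; S_0 is the initial state
    outputs = []    # outputs[t - 1] holds Y_t
    for x in xs:
        s = math.tanh(w_x * x + w_s * states[-1])  # S_t = g1(Wx X_t + Ws S_{t-1})
        states.append(s)
        outputs.append(w_y * s)                     # Y_t = g2(Wy S_t)
    return states, outputs

states, outputs = rnn_forward([1.0, 0.5, -0.5], w_x=0.8, w_s=0.5, w_y=1.2)
```

Each loop iteration is one copy of the network in the unfolded sequence, and every copy reuses the same `w_x`, `w_s`, and `w_y`. That sharing is why the gradients for $W_s$ and $W_x$ end up as sums over timesteps.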

The error at time $t$ compares the actual output $Y_t$ against the desired output $d_t$: $E_t = (d_t - Y_t)^2$.

Updating the weights is where BPTT earns its name. There are three weight matrices to adjust, and they don't all get the same treatment, because they're used differently across the unfolded sequence. Taking the error at $t = 3$ as a running example, the output weight $W_y$ affects $E_3$ only through the current output $Y_3$, so its gradient is a simple two-term chain rule: $\partial E_3 / \partial W_y = (\partial E_3 / \partial Y_3) \cdot (\partial Y_3 / \partial W_y)$.
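A minimal sketch of that two-term chain rule, assuming a scalar RNN with $g_1 = \tanh$, $g_2$ = identity, and $S_0 = 0$. Under those assumptions $\partial E_3/\partial Y_3 = -2(d_3 - Y_3)$ and $\partial Y_3/\partial W_y = S_3$. The function names and values are hypothetical, and the analytic result is checked against a central finite difference.

```python
import math

def forward(xs, w_x, w_s, w_y):
    # Scalar RNN, assuming g1 = tanh, g2 = identity, and S_0 = 0.
    s, states, outputs = 0.0, [], []
    for x in xs:
        s = math.tanh(w_x * x + w_s * s)
        states.append(s)          # S_1 .. S_T
        outputs.append(w_y * s)   # Y_1 .. Y_T
    return states, outputs

def loss(xs, d_last, w_x, w_s, w_y):
    # E_T = (d_T - Y_T)^2 at the final timestep T.
    _, outputs = forward(xs, w_x, w_s, w_y)
    return (d_last - outputs[-1]) ** 2

def grad_wy(xs, d_last, w_x, w_s, w_y):
    """dE_T/dW_y via the two-term chain rule."""
    states, outputs = forward(xs, w_x, w_s, w_y)
    dE_dY = -2.0 * (d_last - outputs[-1])  # dE_T/dY_T
    dY_dWy = states[-1]                    # dY_T/dW_y = S_T when g2 is the identity
    return dE_dY * dY_dWy

xs, d3 = [1.0, 0.5, -0.5], 0.3
analytic = grad_wy(xs, d3, 0.8, 0.5, 1.2)

# Central finite difference in W_y as an independent check.
eps = 1e-6
numeric = (loss(xs, d3, 0.8, 0.5, 1.2 + eps) - loss(xs, d3, 0.8, 0.5, 1.2 - eps)) / (2 * eps)
```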

The hidden-state weight $W_s$ and input weight $W_x$ are different: both were reused at every single timestep leading up to the current one, so the error at time 3 has to account for their contribution at time 1, time 2, and time 3 all at once. That gives a summed chain rule instead of a single term:

$$\frac{\partial E_3}{\partial W_s} = \sum_{i=1}^{3} \frac{\partial E_3}{\partial Y_3} \cdot \frac{\partial Y_3}{\partial S_i} \cdot \frac{\partial S_i}{\partial W_s}$$

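with the identical structure for $\partial E_3 / \partial W_x$, just swapping in $\partial S_i / \partial W_x$. Here $\partial S_i / \partial W_s$ is the immediate derivative at step $i$, holding $S_{i-1}$ fixed. Every term in that sum is itself a product of several derivatives chained across timesteps: $Y_3$ depends on $S_i$ only through $S_3$, so $\frac{\partial Y_3}{\partial S_i} = \frac{\partial Y_3}{\partial S_3} \prod_{k=i+1}^{3} \frac{\partial S_k}{\partial S_{k-1}}$, since $S_3$ depends on $S_2$, which depends on $S_1$, and so on.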
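The summed chain rule can be sketched directly for a scalar RNN, again assuming $g_1 = \tanh$, $g_2$ = identity, and $S_0 = 0$. For $\tanh$ the immediate derivatives are $\partial S_i/\partial W_s = (1 - S_i^2)\,S_{i-1}$ and $\partial S_i/\partial W_x = (1 - S_i^2)\,X_i$, and each link of the chain is $\partial S_k/\partial S_{k-1} = (1 - S_k^2)\,W_s$. The loop walks $i$ from $T$ back to 1, adding one term of the sum per step and extending the running product $\partial S_T/\partial S_i$ by one link. Function names and values are hypothetical; a finite-difference check confirms both sums.

```python
import math

def forward(xs, w_x, w_s, w_y):
    # Scalar RNN, assuming g1 = tanh, g2 = identity, and S_0 = 0.
    states = [0.0]  # states[t] = S_t
    for x in xs:
        states.append(math.tanh(w_x * x + w_s * states[-1]))
    return states, w_y * states[-1]  # all states and the final output Y_T

def loss(xs, d_T, w_x, w_s, w_y):
    _, y_T = forward(xs, w_x, w_s, w_y)
    return (d_T - y_T) ** 2

def bptt_grads(xs, d_T, w_x, w_s, w_y):
    """dE_T/dW_s and dE_T/dW_x for E_T = (d_T - Y_T)^2, as a sum over timesteps."""
    states, y_T = forward(xs, w_x, w_s, w_y)
    T = len(xs)
    dE_dY = -2.0 * (d_T - y_T)   # dE_T/dY_T
    dY_dS_T = w_y                # dY_T/dS_T (g2 = identity)
    dS_T_dS_i = 1.0              # dS_T/dS_i, equal to 1 when i = T
    g_ws = g_wx = 0.0
    for i in range(T, 0, -1):    # i = T, T-1, ..., 1
        local = 1.0 - states[i] ** 2          # tanh'(a_i) = 1 - S_i^2
        # Immediate partials at step i, holding S_{i-1} fixed.
        dSi_dWs = local * states[i - 1]
        dSi_dWx = local * xs[i - 1]           # xs[i - 1] is X_i
        g_ws += dE_dY * dY_dS_T * dS_T_dS_i * dSi_dWs
        g_wx += dE_dY * dY_dS_T * dS_T_dS_i * dSi_dWx
        dS_T_dS_i *= local * w_s              # times dS_i/dS_{i-1} -> dS_T/dS_{i-1}
    return g_ws, g_wx

xs, d3 = [1.0, 0.5, -0.5], 0.3
w_x, w_s, w_y = 0.8, 0.5, 1.2
g_ws, g_wx = bptt_grads(xs, d3, w_x, w_s, w_y)

# Central finite differences as an independent check.
eps = 1e-6
num_ws = (loss(xs, d3, w_x, w_s + eps, w_y) - loss(xs, d3, w_x, w_s - eps, w_y)) / (2 * eps)
num_wx = (loss(xs, d3, w_x + eps, w_s, w_y) - loss(xs, d3, w_x - eps, w_s, w_y)) / (2 * eps)
```

Walking backward and reusing the running product computes exactly the same sum as the equation, but in a single backward pass instead of rebuilding each chain from scratch. The `dS_T_dS_i *= local * w_s` line is where the repeated multiplication across timesteps happens.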

That's precisely the mechanism behind BPTT's two failure modes. Multiply enough factors $\partial S_k / \partial S_{k-1}$ with magnitude below 1 across a long sequence and the gradient shrinks toward zero before it reaches the early timesteps: the vanishing gradient problem, where the network effectively forgets anything more than a few steps back. Multiply enough factors with magnitude above 1 and the gradient blows up instead: exploding gradients, where weight updates become unstable. Both are consequences of the same summed, multiplied chain, not separate bugs, and both are the motivation for the gated architectures, such as the LSTM and GRU, that came next.
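A small numeric sketch of the repeated product $\partial S_T/\partial S_1 = \prod_{k=2}^{T} \partial S_k/\partial S_{k-1}$, which scales the earliest timestep's contribution to the gradient. It assumes a scalar RNN driven by a constant input with $W_x$ fixed at 1. With `linear=True` it further assumes $g_1$ = identity, so every factor is exactly $W_s$; otherwise $g_1 = \tanh$, where saturation pushes $1 - S_k^2$ toward 0 and shrinks the factors further. `chain_to_start` is a hypothetical helper.

```python
import math

def chain_to_start(w_s, T, x=1.0, linear=False):
    """Return dS_T/dS_1 for a scalar RNN run for T steps on a constant input x.

    W_x is fixed at 1 (assumption). Each factor dS_k/dS_{k-1} is g1'(a_k) * W_s:
    exactly W_s when linear=True (g1 = identity), (1 - S_k^2) * W_s for tanh.
    """
    s, prod = 0.0, 1.0
    for k in range(1, T + 1):
        a = x + w_s * s
        s = a if linear else math.tanh(a)
        if k > 1:  # the chain runs from S_1 up to S_T
            prod *= w_s if linear else (1.0 - s ** 2) * w_s
    return prod

print(chain_to_start(0.5, 50, linear=True))  # 0.5**49, roughly 1.8e-15: vanishing
print(chain_to_start(1.5, 50, linear=True))  # 1.5**49, roughly 4e8: exploding
print(chain_to_start(0.9, 50))               # tanh saturation makes it far smaller still
```

Fifty steps is not an unusually long sequence, yet the earliest timestep's influence is either negligible or enormous unless the factors stay very close to 1.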