基础 · 13

循环神经网络与 LSTM

序列有记忆,但记忆会衰退。LSTM 用三道门控住信息流——遗忘、写入、输出,让梯度穿越百步不消失。

14 min read

循环神经网络

前馈神经网络一次只看一个输入,没有"时间"的概念。但语言、语音、股价都是序列——当前时刻的含义依赖之前发生了什么。循环神经网络(RNN)的核心想法很简单:

给网络加一个"记忆":把上一时刻的隐状态传给下一时刻。

具体来说,Elman RNN 在每个时间步 tt 做两件事:

ht=tanh(Wxhxt+Whhht1+bh)h_t = \tanh(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + b_h) yt=Whyhty_t = W_{hy} \cdot h_t

hth_t 是隐状态——网络的"记忆"。它同时接收当前输入 xtx_t 和前一步的记忆 ht1h_{t-1},用 tanh 压到 (1,1)(-1, 1) 之间。yty_t 是当前时刻的输出。

按时间展开

把一个 RNN 沿时间轴"铺平",就得到一个很深的前馈网络——每个时间步是一层,所有层共享同一组权重 WxhW_{xh}WhhW_{hh}。这种视角叫 unrolling,它让我们能用熟悉的反向传播来训练 RNN。

时间反向传播 (BPTT)

既然展开后就是一个深度网络,训练方式自然是反向传播。对 RNN 而言,这叫 BPTT(Backpropagation Through Time):从最后一个时间步的损失出发,沿时间轴一路乘回去。

损失对隐状态的梯度需要逐步回传:

Lht=LhTk=t+1Thkhk1\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial h_T} \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}

关键在这个连乘。每一步的局部导数是:

hkhk1=diag(1hk2)Whh\frac{\partial h_k}{\partial h_{k-1}} = \text{diag}(1 - h_k^2) \cdot W_{hh}

其中 diag(1hk2)\text{diag}(1 - h_k^2) 是 tanh 的导数(最大值为 1,通常远小于 1)。

梯度消失问题

问题来了:这个连乘要乘 TtT - t 次。

  • 如果 Whh\|W_{hh}\| 的谱范数 < 1,连乘指数衰减 → 梯度消失。早期时间步的信号传不到后面,网络"忘了"远处的输入。
  • 如果 Whh\|W_{hh}\| 的谱范数 > 1,连乘指数增长 → 梯度爆炸。参数更新一步跳出几百倍,训练直接崩溃。

梯度爆炸可以用梯度裁剪(gradient clipping)粗暴解决——超过阈值就缩放。但梯度消失没有这么简单的补丁。它意味着:

简单 RNN 无法学习长距离依赖。 实际上超过 10–20 步,梯度就已经接近零了。

这就是 LSTM 被发明的原因。

LSTM

LSTM(Long Short-Term Memory) 的核心思想是:给记忆一条"高速公路"——细胞状态 CtC_t。信息可以在这条路上几乎无损地流过多个时间步,不被反复乘以权重矩阵。

为了控制什么信息进入、保留、输出,LSTM 引入三道(gate),每道门的输出都在 [0,1][0,1] 之间(sigmoid),然后和信号做逐元素乘法:

遗忘门:该丢掉什么

ft=σ(Wf[ht1,xt]+bf)f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)

ftf_t 接近 0 表示"清除这段记忆",接近 1 表示"完整保留"。当话题切换时,遗忘门会关闭。

输入门 + 候选记忆:该写入什么

it=σ(Wi[ht1,xt]+bi)i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) C~t=tanh(Wc[ht1,xt]+bc)\tilde{C}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)

输入门决定"写多少",候选记忆是"写什么内容"。两者逐元素相乘后加到细胞状态上。

细胞状态更新

Ct=ftCt1+itC~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t

这就是那条"高速公路":旧记忆乘以遗忘门(保留多少),加上新信息乘以输入门(写入多少)。梯度沿 CC 回传时,只需乘以 ftf_t——只要遗忘门接近 1,梯度几乎无损地流过

输出门:该输出什么

ot=σ(Wo[ht1,xt]+bo)o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ht=ottanh(Ct)h_t = o_t \odot \tanh(C_t)

细胞状态是完整的记忆,但不是所有记忆都需要在当前时刻暴露。输出门控制隐状态 hth_t 呈现记忆的哪个部分。

为什么 LSTM 解决了梯度消失

关键在细胞状态的更新公式 Ct=ftCt1+C_t = f_t \odot C_{t-1} + \ldots。对 Ct1C_{t-1} 求导只得到 ftf_t——一个接近 1 的标量。这意味着梯度可以沿细胞状态线性流过任意长度,不像简单 RNN 那样指数衰减。

用一个比喻:简单 RNN 的记忆是在纸上反复涂写——每一步都把整张纸重新画一遍,旧内容很快模糊。LSTM 的细胞状态是一块白板,只有拿到"遗忘门钥匙"才能擦,拿到"输入门钥匙"才能写。

点击 Step 开始
LSTM CellCell State C_t = 0.000.00遗忘门 f0.00输入门 i0.00候选 C̃0.00输出门 of_t = σ(W_f·[h_t-1, x_t] + b_f)i_t = σ(W_i·[h_t-1, x_t] + b_i)C̃_t = tanh(W_c·[h_t-1, x_t] + b_c)o_t = σ(W_o·[h_t-1, x_t] + b_o)点击 "Step" 逐步观察 LSTM 门控变化

注意第 5 步 "很":话题切换时遗忘门关闭 (f≈0.3),清除旧记忆并写入新信息

LSTM 门控可视化 — 模拟值展示门控如何控制信息流

GRU:简化版 LSTM

GRU(Gated Recurrent Unit) 是 Cho 等人在 2014 年提出的简化方案。它把遗忘门和输入门合并为一个更新门 ztz_t,同时去掉了独立的细胞状态:

zt=σ(Wz[ht1,xt])z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) rt=σ(Wr[ht1,xt])r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) h~t=tanh(W[rtht1,xt])\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t]) ht=(1zt)ht1+zth~th_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

设计哲学:ztz_t 同时控制"忘多少旧的"和"写多少新的"——忘掉的部分正好被新信息填上,两者互补。重置门 rtr_t 决定计算候选值时参考多少旧隐状态。

GRU 参数更少、训练更快,在中等长度序列上表现与 LSTM 相当。但在需要非常精细的记忆控制(如代码生成、长文档)时,LSTM 仍有优势。

这个想法在前沿里