第5章:循环神经网络(RNN)
RNN的基本原理
1. 什么是循环神经网络(RNN)?
循环神经网络(Recurrent Neural Network, RNN)是一类专门用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN通过引入时间维度上的循环连接,使网络能够保留历史信息,从而实现对序列数据的建模。
核心特点:
- 时间展开性:RNN在每个时间步共享相同的权重参数。
- 隐状态(Hidden State):用于存储历史信息的内部记忆单元。
- 动态输入输出:支持变长序列输入和输出(如文本、时间序列)。
2. RNN的基本结构
2.1 展开计算图
RNN的计算过程可以通过时间展开表示。对于一个输入序列 ,RNN在每个时间步 执行以下操作:
- 输入层:接收当前时间步的输入 。
- 隐状态更新:
- :当前隐状态
- 、:权重矩阵
- :偏置项
- :激活函数(如Tanh或ReLU)
- 输出层(可选):
- :输出激活函数(如Softmax分类任务)
2.2 循环连接的直观理解
- 记忆能力:隐状态 是过去所有输入信息的压缩表示。
- 参数共享:所有时间步共享同一组参数(),减少模型复杂度。
3. RNN的数学表达
3.1 前向传播公式
对于时间步 :
- 隐状态:
- 输出:
3.2 反向传播(BPTT)
RNN通过**随时间反向传播(Backpropagation Through Time, BPTT)**计算梯度。由于时间步之间存在依赖关系,梯度会沿时间链式传播,可能导致:
- 梯度消失:长序列中较早时间步的梯度趋近于零。
- 梯度爆炸:梯度值指数级增长。
4. RNN的局限性
4.1 短期记忆问题
- 问题:标准RNN难以捕获长期依赖关系(如相隔几十步的文本依赖)。
- 原因:梯度消失导致早期时间步的信息无法有效传递。
4.2 解决方案的引入
后续改进模型(如LSTM、GRU)通过门控机制缓解此问题(见下一节)。
5. 代码示例(PyTorch实现)
import torch
import torch.nn as nn
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x: (batch_size, seq_len, input_size)
out, h_n = self.rnn(x) # out: (batch_size, seq_len, hidden_size)
out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出
return out
6. 应用场景
- 自然语言处理:语言建模、机器翻译。
- 时间序列预测:股票价格、气象数据。
- 语音识别:音频信号序列建模。
总结
RNN通过循环连接实现对序列数据的建模,但其基础结构存在短期记忆缺陷。后续章节将介绍LSTM、GRU等改进模型,以及更强大的Transformer架构。
