第5章:循环神经网络(RNN)
RNN文本生成
1. 文本生成的基本原理
- 序列建模:RNN通过记忆先前时间步的隐藏状态,捕捉文本序列中的上下文依赖关系。
- 自回归生成:每一步基于历史token预测下一个token(如字符、单词或子词),形成迭代生成过程。
- 概率分布输出:Softmax层输出词汇表上的概率分布,通过采样(如贪婪搜索、随机采样)选择下一个token。
2. 关键技术实现
2.1 模型架构
- 单层/多层RNN:LSTM或GRU单元解决长程依赖问题。
- Embedding层:将离散token映射为连续向量表示。
- 温度参数(Temperature):控制生成多样性(高温增加随机性,低温趋向确定性)。
2.2 训练流程
- 数据准备:将文本分割为固定长度的序列(如每50个字符为一个样本)。
- Teacher Forcing:训练时使用真实历史token作为输入,而非模型生成结果。
- 损失函数:交叉熵损失(Cross-Entropy)衡量预测分布与真实token的差异。
2.3 生成策略
| 方法 | 描述 | 优缺点 |
|---|---|---|
| 贪婪搜索 | 每一步选择概率最高的token | 高效但可能陷入重复模式 |
| 随机采样 | 按概率分布随机选择token | 多样性高,但可能不连贯 |
| Beam Search | 保留Top-K候选序列,逐步扩展 | 平衡质量与多样性,计算成本较高 |
3. 应用案例
3.1 诗歌生成
- 输入:古诗开头(如“春江潮水连海平”)。
- 输出:生成符合平仄和意境的后续诗句。
3.2 代码补全
- 示例:输入部分Python代码,模型预测后续语法合理的代码段。
4. 挑战与改进
- 曝光偏差(Exposure Bias):训练与生成时输入分布不一致→可通过计划采样(Scheduled Sampling)缓解。
- 重复生成:引入惩罚机制(如重复token的概率衰减)。
- 现代替代方案:Transformer(如GPT)通过自注意力机制提升长文本生成质量。
5. 代码示例(PyTorch)
import torch
import torch.nn as nn
class TextGenerator(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
self.fc = nn.Linear(hidden_dim, vocab_size)
def forward(self, x, hidden):
x = self.embedding(x)
out, hidden = self.rnn(x, hidden)
return self.fc(out), hidden
# 生成函数示例
def generate_text(model, start_seq, length, temperature=1.0):
model.eval()
tokens = tokenizer.encode(start_seq)
hidden = None
for _ in range(length):
input_tensor = torch.tensor([tokens[-1]]).unsqueeze(0)
logits, hidden = model(input_tensor, hidden)
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
tokens.append(next_token)
return tokenizer.decode(tokens)
6. 扩展阅读
- 数据集:WikiText、OpenWebText等大规模语料库。
- 评估指标:BLEU(机器翻译)、Perplexity(语言模型困惑度)。
- 进阶技术:条件生成(如指定风格的文本)、强化学习微调(如人类偏好对齐)。
此内容涵盖理论、实现细节与实战示例,可根据需要增加具体任务(如对话生成)的案例分析或可视化生成过程(如注意力权重热力图)。