第5章:循环神经网络(RNN)
5.5 RNN文本生成
1. 文本生成的基本原理
文本生成是RNN的经典应用之一,其核心思想是通过学习序列数据的概率分布,预测下一个字符或单词,逐步生成连贯的文本。RNN通过隐藏状态(Hidden State)捕捉上下文信息,使生成的文本具有时序依赖性。
关键步骤
- 数据准备:将文本转换为字符级或词级的序列(如One-Hot编码或词嵌入)。
- 模型训练:输入序列的前N个字符,预测第N+1个字符。
- 生成文本:给定种子序列,迭代预测下一个字符并扩展序列。
2. 模型架构设计
2.1 字符级 vs. 词级生成
- 字符级生成:以单个字符为单元,模型更轻量但长程依赖较弱(如生成莎士比亚风格文本)。
- 词级生成:以单词为单元,依赖词嵌入(Word2Vec、GloVe),生成结果更语义化但需更大语料库。
2.2 网络结构选择
- 基础RNN:简单但易出现梯度消失问题。
- LSTM/GRU:更适合长文本生成,能捕捉长期依赖(如生成新闻标题或诗歌)。
3. 训练与生成策略
3.1 损失函数与优化
- 使用交叉熵损失衡量预测字符与真实字符的差异。
- 通过Teacher Forcing技术加速训练:将真实字符而非预测字符作为下一步输入。
3.2 生成控制技术
- Temperature Sampling:调整Softmax输出的随机性:
- 高温(>1):生成更多样化但可能不连贯的文本。
- 低温(<1):生成更保守、确定性高的文本。
- Beam Search:保留多个候选序列,选择整体概率最高的路径(常用于机器翻译)。
4. 代码示例(PyTorch)
import torch
import torch.nn as nn
class CharRNN(nn.Module):
def __init__(self, vocab_size, hidden_size):
super().__init__()
self.embed = nn.Embedding(vocab_size, hidden_size)
self.rnn = nn.LSTM(hidden_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden):
x = self.embed(x)
out, hidden = self.rnn(x, hidden)
return self.fc(out), hidden
# 生成函数示例
def generate_text(model, start_seq, length, temperature=1.0):
model.eval()
hidden = None
input_seq = torch.tensor([char_to_idx[c] for c in start_seq])
for _ in range(length):
output, hidden = model(input_seq.unsqueeze(0), hidden)
prob = nn.functional.softmax(output[-1] / temperature, dim=-1)
next_char = torch.multinomial(prob, 1).item()
input_seq = torch.cat([input_seq, torch.tensor([next_char])])
return ''.join([idx_to_char[i] for i in input_seq.tolist()])
5. 应用案例
5.1 创意写作
- 诗歌生成:训练模型学习押韵和节奏(如使用宋词语料库)。
- 故事续写:输入开头段落,生成后续情节。
5.2 代码补全
- 基于RNN的IDE插件(如GitHub Copilot早期版本)。
6. 挑战与改进
- 长文本一致性:Transformer(如GPT)通过自注意力机制更擅长生成长文本。
- 数据偏见:模型可能复制训练数据中的偏见(需过滤敏感内容)。
延伸阅读
- 《The Unreasonable Effectiveness of Recurrent Neural Networks》(Andrej Karpathy博客)
- Hugging Face的
transformers库实现(如GPT-2文本生成)。
---
**注**:可根据实际需求调整代码框架(如使用TensorFlow/Keras)或增加可视化生成结果对比图。