第7章:生成模型
7.1 生成对抗网络(GAN)
1. 基本概念
生成对抗网络(Generative Adversarial Network, GAN)是由Ian Goodfellow等人于2014年提出的一种生成模型框架。其核心思想是通过两个神经网络——**生成器(Generator)和判别器(Discriminator)**的对抗训练,实现数据生成。
- 生成器:从随机噪声中生成逼真数据(如图像、文本)。
- 判别器:区分生成数据与真实数据,输出概率值(0到1)。
2. 数学原理
GAN的优化目标是一个极小极大博弈(Minimax Game):
其中:
- :真实数据分布
- :噪声分布(如高斯分布)
- :判别器对真实数据的输出
- :生成器生成的假数据
3. 训练过程
- 固定生成器,训练判别器:最大化区分真实与生成数据的能力。
- 固定判别器,训练生成器:最小化判别器对生成数据的识别准确率。
- 交替迭代直至纳什均衡。
4. 经典变体
| 模型 | 核心改进 | 应用场景 |
|---|---|---|
| DCGAN | 使用卷积层和批量归一化 | 图像生成 |
| WGAN | 用Wasserstein距离替代JS散度 | 稳定训练 |
| CycleGAN | 引入循环一致性损失 | 图像风格迁移 |
5. 代码示例(PyTorch)
# 生成器定义
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 784),
nn.Tanh() # 输出归一化到[-1,1]
)
def forward(self, z):
return self.main(z)
6. 挑战与解决方案
- 模式崩溃(Mode Collapse):生成器仅产生单一类型样本。
解决方案:使用Mini-batch判别或Unrolled GAN。 - 训练不稳定:判别器过强导致梯度消失。
解决方案:采用WGAN-GP的梯度惩罚。
7. 应用案例
- 艺术创作:生成虚构人脸(如ThisPersonDoesNotExist.com)
- 数据增强:医疗图像合成以解决小样本问题
- 超分辨率:ESRGAN提升图像分辨率
扩展阅读:
此内容包含理论推导、实现细节和实用资源,符合技术书籍的专业性要求。如需增加具体案例的代码或数学证明细节,可进一步扩展。