第7章:生成模型
7.1 生成对抗网络(GAN)
1. 基本概念
生成对抗网络(Generative Adversarial Network, GAN) 是由Ian Goodfellow等人于2014年提出的一种生成模型框架。其核心思想是通过两个神经网络——生成器(Generator) 和 判别器(Discriminator) 的对抗训练,学习数据的分布并生成逼真的样本。
- 生成器(G):输入随机噪声(如高斯分布),输出与真实数据分布相似的样本(如图像、文本等)。
- 判别器(D):输入真实数据或生成器生成的样本,输出一个概率值(0到1),表示输入数据是真实数据的可能性。
2. 数学原理
GAN的训练过程可以看作一个极小极大博弈(Minimax Game),目标函数如下:
其中:
- 是真实数据的分布。
- 是噪声分布(如高斯分布)。
- 是判别器对真实数据的输出。
- 是生成器生成的样本。
3. 训练过程
- 固定生成器,训练判别器:
判别器尝试最大化对真实数据和生成数据的分类准确率。 - 固定判别器,训练生成器:
生成器尝试生成更逼真的样本以“欺骗”判别器。 - 交替迭代:
重复上述步骤,直到生成器生成的样本与真实数据难以区分(纳什均衡)。
4. GAN的变体与改进
由于原始GAN存在训练不稳定、模式崩溃(Mode Collapse)等问题,研究者提出了多种改进模型:
- DCGAN(Deep Convolutional GAN):使用卷积神经网络改进生成器和判别器结构。
- WGAN(Wasserstein GAN):通过Wasserstein距离优化损失函数,提升训练稳定性。
- CycleGAN:用于无监督的图像到图像转换(如风格迁移)。
5. 应用场景
- 图像生成:生成逼真的人脸、风景等(如StyleGAN)。
- 数据增强:为小样本任务生成合成数据。
- 超分辨率重建:从低分辨率图像生成高分辨率版本。
- 艺术创作:生成绘画、音乐等。
6. 代码示例(PyTorch)
import torch
import torch.nn as nn
# 定义生成器和判别器
class Generator(nn.Module):
def __init__(self, latent_dim):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 784),
nn.Tanh() # 输出范围[-1, 1]
)
def forward(self, z):
return self.model(z)
# 训练循环(伪代码)
for epoch in range(epochs):
for real_data in dataloader:
# 训练判别器
d_optimizer.zero_grad()
real_loss = criterion(D(real_data), real_labels)
fake_data = G(torch.randn(batch_size, latent_dim))
fake_loss = criterion(D(fake_data.detach()), fake_labels)
d_loss = real_loss + fake_loss
d_loss.backward()
d_optimizer.step()
# 训练生成器
g_optimizer.zero_grad()
g_loss = criterion(D(fake_data), real_labels)
g_loss.backward()
g_optimizer.step()
7. 挑战与局限性
- 训练不稳定:生成器和判别器的平衡难以控制。
- 模式崩溃:生成器可能只生成少数几种样本。
- 评估困难:缺乏统一的量化指标(常用Inception Score或FID)。
下一节:7.2 变分自编码器(VAE)
