第5章:分数生成模型(Score-Based Generative Models)
条件生成与指导采样(Classifier-Guided Sampling)
1. 核心思想
条件生成与指导采样(Classifier-Guided Sampling)是一种通过外部分类器(或条件信息)引导扩散模型生成过程的技术。其核心思想是:
- 条件控制:在生成过程中引入类别标签、文本描述或其他辅助信息,实现对生成样本的精确控制。
- 梯度引导:利用分类器输出的梯度信息调整逆向扩散过程的噪声预测方向,使生成样本满足特定条件。
数学上,该技术通过修改分数函数(score function)实现:
其中:
- 是无条件扩散模型的分数
- 是分类器给出的条件概率
- 是目标条件(如类别标签)
2. 实现步骤
(1)分类器训练
需预训练一个分类器 ,能够对噪声数据 预测条件 。关键点:
- 分类器需在不同噪声级别的数据上训练(与扩散过程同步)
- 通常使用与扩散模型相同的U-Net结构,但输出改为类别概率
(2)采样过程改进
在逆向扩散的每一步中:
- 计算无条件分数
- 计算分类器梯度
- 组合分数并调整步长:
其中 是指导强度系数(guidance scale)
3. 代码示例(PyTorch伪代码)
def classifier_guided_sample(model, classifier, y, s=1.0):
# 初始化噪声
x_T = torch.randn(shape)
for t in reversed(range(T)):
# 无条件噪声预测
eps_uncond = model(x_t, t)
# 分类器梯度计算
with torch.enable_grad():
x_in = x_t.detach().requires_grad_(True)
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(batch_size), y]
grad = torch.autograd.grad(selected.sum(), x_in)[0]
# 组合分数
eps = eps_uncond - s * (1 - alpha_bar[t])**0.5 * grad
x_{t-1} = denoising_step(x_t, eps, t)
return x_0
4. 应用案例
(1)ImageNet条件生成
- 使用预训练的ResNet分类器引导256×256图像生成
- 指导强度 时,FID从3.85降至2.97(DDPM论文结果)
(2)文本到图像生成
- 在Stable Diffusion中,CLIP文本编码器充当隐式分类器
- 通过调整指导系数控制生成图像与文本的匹配度
5. 技术优势与局限
| 优势 | 局限 |
|---|---|
| 无需重新训练扩散模型 | 需额外训练噪声鲁棒分类器 |
| 可灵活调整控制强度 | 高指导系数可能导致样本质量下降 |
| 兼容离散和连续条件 | 分类器与扩散模型的噪声分布需对齐 |
6. 扩展变体
- 无分类器指导(Classifier-Free Guidance):通过联合训练条件/无条件模型避免分类器依赖
- 多条件融合:支持文本+类别+分割图等多模态条件输入
图表建议
- 逆向扩散过程梯度修正示意图(对比原始采样与指导采样路径)
- 不同指导系数下的生成样本对比网格
- 分类器结构与传统分类器的噪声适应对比图
