DDPM的核心训练目标是最小化负对数似然的变分上界(ELBO),其损失函数可分解为以下三部分:
正向过程损失(常数项,通常忽略):
LT=DKL(q(xT∣x0)∥p(xT))
逆向过程损失(关键优化项):
Lt−1=Eq[DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))]
重构损失:
L0=−logpθ(x0∣x1)
通过推导可得简化后的均方误差损失(实际实现形式):
Lsimple=Et,x0,ϵ[∥ϵ−ϵθ(xt,t)∥2]
其中 xt=αtx0+1−αtϵ,ϵ为标准高斯噪声。
def train_step(model, x0, optimizer):
t = torch.randint(0, T, (x0.shape[0],))
alpha_bar = compute_alpha_bar(t)
eps = torch.randn_like(x0)
xt = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1-alpha_bar) * eps
eps_pred = model(xt, t)
loss = F.mse_loss(eps_pred, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss
噪声调度策略:
- 线性调度:βt从β1=10−4到βT=0.02线性增长
- 余弦调度(改进版):αt=f(0)f(t), f(t)=cos(1+st/T+s⋅2π)2
模型架构选择:
- U-Net结构(含残差连接)
- 时间步t通过正弦位置编码嵌入
- 自注意力机制(用于高分辨率生成)
训练技巧:
- 混合精度训练(FP16)
- 梯度裁剪(防止梯度爆炸)
- EMA(指数移动平均)模型参数
| 超参数 | 典型值 |
|---|
| Batch size | 128 |
| Learning rate | 2e-4 |
| Training steps | 500k |
| EMA decay | 0.9999 |
| Time steps T | 1000 |
训练曲线示例:
图:噪声预测损失随训练步数的变化趋势
逆向过程真实后验分布:
q(xt−1∣xt,x0)=N(μ~t(xt,x0),β~tI)
其中:
μ~t=1−αtαt−1βtx0+1−αtαt(1−αt−1)xt
β~t=1−αt1−αt−1βt