Tailwind CSSTailwind CSS
Home
  • Tailwind CSS 书籍目录
  • Vue 3 开发实战指南
  • React 和 Next.js 学习
  • TypeScript
  • React开发框架书籍大纲
  • Shadcn学习大纲
  • Swift 编程语言:从入门到进阶
  • SwiftUI 学习指南
  • 函数式编程大纲
  • Swift 异步编程语言
  • Swift 协议化编程
  • SwiftUI MVVM 开发模式
  • SwiftUI 图表开发书籍
  • SwiftData
  • ArkTS编程语言:从入门到精通
  • 仓颉编程语言:从入门到精通
  • 鸿蒙手机客户端开发实战
  • WPF书籍
  • C#开发书籍
learn
  • Java编程语言
  • Kotlin 编程入门与实战
  • /python/outline.html
  • AI Agent
  • MCP (Model Context Protocol) 应用指南
  • 深度学习
  • 深度学习
  • 强化学习: 理论与实践
  • 扩散模型书籍
  • Agentic AI for Everyone
langchain
Home
  • Tailwind CSS 书籍目录
  • Vue 3 开发实战指南
  • React 和 Next.js 学习
  • TypeScript
  • React开发框架书籍大纲
  • Shadcn学习大纲
  • Swift 编程语言:从入门到进阶
  • SwiftUI 学习指南
  • 函数式编程大纲
  • Swift 异步编程语言
  • Swift 协议化编程
  • SwiftUI MVVM 开发模式
  • SwiftUI 图表开发书籍
  • SwiftData
  • ArkTS编程语言:从入门到精通
  • 仓颉编程语言:从入门到精通
  • 鸿蒙手机客户端开发实战
  • WPF书籍
  • C#开发书籍
learn
  • Java编程语言
  • Kotlin 编程入门与实战
  • /python/outline.html
  • AI Agent
  • MCP (Model Context Protocol) 应用指南
  • 深度学习
  • 深度学习
  • 强化学习: 理论与实践
  • 扩散模型书籍
  • Agentic AI for Everyone
langchain
  • 第4章:去噪扩散概率模型(DDPM)

第4章:去噪扩散概率模型(DDPM)

4.3 损失函数与训练细节

理论推导

DDPM的核心训练目标是最小化负对数似然的变分上界(ELBO),其损失函数可分解为以下三部分:

  1. 正向过程损失(常数项,通常忽略):

    LT=DKL(q(xT∣x0)∥p(xT))L_T = D_{KL}(q(x_T|x_0) \parallel p(x_T)) LT​=DKL​(q(xT​∣x0​)∥p(xT​))

  2. 逆向过程损失(关键优化项):

    Lt−1=Eq[DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))]L_{t-1} = \mathbb{E}_q\left[ D_{KL}(q(x_{t-1}|x_t,x_0) \parallel p_\theta(x_{t-1}|x_t)) \right] Lt−1​=Eq​[DKL​(q(xt−1​∣xt​,x0​)∥pθ​(xt−1​∣xt​))]

  3. 重构损失:

    L0=−log⁡pθ(x0∣x1)L_0 = -\log p_\theta(x_0|x_1) L0​=−logpθ​(x0​∣x1​)

通过推导可得简化后的均方误差损失(实际实现形式):

Lsimple=Et,x0,ϵ[∥ϵ−ϵθ(xt,t)∥2]L_\text{simple} = \mathbb{E}_{t,x_0,\epsilon}\left[ \| \epsilon - \epsilon_\theta(x_t,t) \|^2 \right] Lsimple​=Et,x0​,ϵ​[∥ϵ−ϵθ​(xt​,t)∥2]

其中 xt=αtx0+1−αtϵx_t = \sqrt{\alpha_t}x_0 + \sqrt{1-\alpha_t}\epsilonxt​=αt​​x0​+1−αt​​ϵ,ϵ\epsilonϵ为标准高斯噪声。

训练算法

def train_step(model, x0, optimizer):
    # 1. 随机采样时间步
    t = torch.randint(0, T, (x0.shape[0],))
    
    # 2. 生成带噪样本
    alpha_bar = compute_alpha_bar(t)  # 预计算的噪声调度系数
    eps = torch.randn_like(x0)
    xt = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1-alpha_bar) * eps
    
    # 3. 预测噪声
    eps_pred = model(xt, t)
    
    # 4. 计算损失
    loss = F.mse_loss(eps_pred, eps)
    
    # 5. 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

关键实现细节

  1. 噪声调度策略:

    • 线性调度:βt\beta_tβt​从β1=10−4\beta_1=10^{-4}β1​=10−4到βT=0.02\beta_T=0.02βT​=0.02线性增长
    • 余弦调度(改进版):αt=f(t)f(0)\alpha_t = \frac{f(t)}{f(0)}αt​=f(0)f(t)​, f(t)=cos⁡(t/T+s1+s⋅π2)2f(t)=\cos(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2})^2f(t)=cos(1+st/T+s​⋅2π​)2
  2. 模型架构选择:

    • U-Net结构(含残差连接)
    • 时间步ttt通过正弦位置编码嵌入
    • 自注意力机制(用于高分辨率生成)
  3. 训练技巧:

    • 混合精度训练(FP16)
    • 梯度裁剪(防止梯度爆炸)
    • EMA(指数移动平均)模型参数

案例研究:CIFAR-10训练

超参数典型值
Batch size128
Learning rate2e-4
Training steps500k
EMA decay0.9999
Time steps T1000

训练曲线示例:

图:噪声预测损失随训练步数的变化趋势

数学补充

逆向过程真实后验分布:

q(xt−1∣xt,x0)=N(μ~t(xt,x0),β~tI)q(x_{t-1}|x_t,x_0) = \mathcal{N}(\tilde{\mu}_t(x_t,x_0), \tilde{\beta}_t I) q(xt−1​∣xt​,x0​)=N(μ~​t​(xt​,x0​),β~​t​I)

其中:

μ~t=αt−1βt1−αtx0+αt(1−αt−1)1−αtxt\tilde{\mu}_t = \frac{\sqrt{\alpha_{t-1}}\beta_t}{1-\alpha_t}x_0 + \frac{\sqrt{\alpha_t}(1-\alpha_{t-1})}{1-\alpha_t}x_t μ~​t​=1−αt​αt−1​​βt​​x0​+1−αt​αt​​(1−αt−1​)​xt​

β~t=1−αt−11−αtβt\tilde{\beta}_t = \frac{1-\alpha_{t-1}}{1-\alpha_t}\beta_t β~​t​=1−αt​1−αt−1​​βt​

Last Updated:: 5/28/25, 11:37 PM