代码示例与实战指南(PyTorch/TensorFlow实现)
1. 基础扩散模型实现(PyTorch)
1.1 正向扩散过程
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Build the linear noise (beta) schedule used by DDPM.

    Args:
        timesteps: number of diffusion steps T.
        beta_start: variance at the first step.
        beta_end: variance at the last step.

    Returns:
        1-D tensor of length ``timesteps`` with evenly spaced betas.
    """
    schedule = torch.linspace(beta_start, beta_end, timesteps)
    return schedule
def forward_diffusion(x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    """Diffuse clean samples to timestep t via the closed form q(x_t | x_0).

    Args:
        x0: clean images, shape (B, C, H, W).
        t: integer timestep per sample, shape (B,).
        sqrt_alphas_cumprod: precomputed sqrt of cumulative alpha products.
        sqrt_one_minus_alphas_cumprod: precomputed sqrt(1 - cumprod(alpha)).

    Returns:
        Tuple of (noisy images at step t, the Gaussian noise that was added).
    """
    eps = torch.randn_like(x0)
    # Per-sample scalars, broadcast over the (C, H, W) dimensions.
    signal_scale = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    noise_scale = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    noisy = signal_scale * x0 + noise_scale * eps
    return noisy, eps
1.2 U-Net噪声预测模型
class Block(nn.Module):
    """Two-conv block with additive timestep-embedding conditioning.

    The time embedding is projected to ``out_ch`` channels and added as a
    per-channel bias between the two 3x3 convolutions.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

    def forward(self, x, t):
        """Return conv2(conv1(x) + projected time embedding)."""
        features = self.conv1(x)
        # (B, out_ch) -> (B, out_ch, 1, 1) so it broadcasts over H and W.
        conditioning = self.time_mlp(t)[:, :, None, None]
        return self.conv2(features + conditioning)
class SimpleUnet(nn.Module):
    """Skeleton of a simplified U-Net noise-prediction network.

    NOTE(review): tutorial stub — the down/upsampling modules and the
    forward pass are intentionally left unimplemented.
    """

    def __init__(self):
        super().__init__()
        # Implement the downsampling and upsampling modules here...

    def forward(self, x, timestep):
        # Implement the U-Net forward pass here...
        # NOTE(review): `predicted_noise` is undefined in this stub, so
        # calling forward() as written raises NameError until implemented.
        return predicted_noise
2. 训练循环实现
def train_diffusion(model, dataloader, optimizer, timesteps, device, num_epochs=1):
    """Train a noise-prediction model with the simple DDPM objective.

    Args:
        model: network predicting the added noise from (noisy_images, t).
        dataloader: yields image batches, or (images, labels) tuples as
            produced by torchvision datasets.
        optimizer: torch optimizer over model.parameters().
        timesteps: number of diffusion steps T.
        device: device for the model inputs and the noise schedule.
        num_epochs: passes over the dataloader. (Was an undefined global in
            the original code — now an explicit, backward-compatible parameter.)
    """
    # Precompute all schedule quantities once, on the target device,
    # instead of recomputing the square roots every batch.
    betas = linear_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    model.train()
    for _ in range(num_epochs):
        for batch in dataloader:
            # torchvision dataloaders yield (images, labels); keep the images.
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            batch = batch.to(device)
            optimizer.zero_grad()
            # One uniformly random timestep per sample.
            t = torch.randint(0, timesteps, (batch.size(0),), device=device)
            # Forward diffusion: corrupt the clean batch to step t.
            noisy_images, noise = forward_diffusion(
                batch, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod
            )
            predicted_noise = model(noisy_images, t)
            # DDPM "simple" loss: MSE between true and predicted noise.
            loss = F.mse_loss(predicted_noise, noise)
            loss.backward()
            optimizer.step()
3. 采样过程实现
@torch.no_grad()
def sample(model, image_size, timesteps, device):
    """Generate one image by ancestral DDPM sampling.

    Fixes the original code, which read `betas`, `alphas` and
    `alphas_cumprod` from undefined globals; the schedule is now built
    locally with the same defaults as `linear_beta_schedule`.

    Args:
        model: noise-prediction network, called as model(img, t).
        image_size: output spatial resolution (square image).
        timesteps: number of reverse diffusion steps T.
        device: device on which to run the sampler.

    Returns:
        Tensor of shape (1, 3, image_size, image_size).
    """
    # Linear schedule, identical to linear_beta_schedule's defaults.
    betas = torch.linspace(1e-4, 0.02, timesteps, device=device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Start from pure Gaussian noise x_T.
    img = torch.randn((1, 3, image_size, image_size), device=device)
    for i in reversed(range(timesteps)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        predicted_noise = model(img, t)
        alpha_t = alphas[t]
        alpha_bar_t = alphas_cumprod[t]
        beta_t = betas[t]
        # No noise is added at the final (i == 0) step.
        noise = torch.randn_like(img) if i > 0 else torch.zeros_like(img)
        # DDPM posterior mean: (x_t - beta_t/sqrt(1-alpha_bar_t) * eps) / sqrt(alpha_t)
        img = (img - beta_t * predicted_noise / torch.sqrt(1.0 - alpha_bar_t)) / torch.sqrt(alpha_t)
        img = img + torch.sqrt(beta_t) * noise
    return img
4. TensorFlow实现示例
import tensorflow as tf
from tensorflow.keras import layers
class DiffusionModel(tf.keras.Model):
    """Noise-prediction network skeleton (TensorFlow/Keras).

    NOTE(review): tutorial stub — the layers and the forward pass are left
    unimplemented; `predicted_noise` in `call` is undefined as written and
    raises NameError until the network is filled in.
    """

    def __init__(self, image_size, widths, block_depth):
        super().__init__()
        # Define the noise-prediction network here...

    def call(self, x, t):
        # Implement the forward pass here...
        return predicted_noise
def train_step(model, images, optimizer, timesteps):
    """Run one diffusion-training optimization step (TensorFlow).

    Args:
        model: Keras model predicting noise from (noisy_images, t).
        images: batch of clean training images.
        optimizer: tf.keras optimizer.
        timesteps: number of diffusion steps T.

    Returns:
        The scalar MSE loss for this batch.
    """
    with tf.GradientTape() as tape:
        batch_size = tf.shape(images)[0]
        # One uniformly random timestep per image.
        t = tf.random.uniform([batch_size], 0, timesteps, dtype=tf.int32)
        # Forward diffusion: corrupt the batch with Gaussian noise.
        noise = tf.random.normal(tf.shape(images))
        noisy_images = add_noise(images, noise, t)
        predicted_noise = model(noisy_images, t, training=True)
        # MSE between the true and the predicted noise.
        loss = tf.reduce_mean(tf.square(predicted_noise - noise))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
5. 实战案例:图像生成
5.1 数据预处理
# NOTE(review): this snippet requires `from torchvision import transforms, datasets`
# and `from torch.utils.data import DataLoader`, which are not imported here — confirm.
transform = transforms.Compose([
    transforms.Resize((64, 64)),        # upscale CIFAR-10's 32x32 images to 64x64
    transforms.ToTensor(),              # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize([0.5], [0.5])  # rescale to [-1, 1], matching the sampler's output range
])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
5.2 训练与采样
# Initialize model and optimizer
# NOTE(review): assumes `device` is defined earlier (e.g. torch.device("cuda")) — confirm.
model = SimpleUnet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Train the noise-prediction model
train_diffusion(model, dataloader, optimizer, timesteps=1000, device=device)
# Sample one generated image
generated_image = sample(model, image_size=64, timesteps=1000, device=device)
6. 进阶技巧
6.1 加速采样(DDIM)
@torch.no_grad()
def ddim_sample(model, image_size, timesteps, eta=0.0):
    """DDIM sampling stub (deterministic when eta == 0.0).

    NOTE(review): unimplemented placeholder — returns None as written.
    """
    # Implement the DDIM sampling algorithm here...
    pass
6.2 分类器引导
def classifier_guided_sample(model, classifier, x, t, guidance_scale=2.0):
    """Classifier-guided sampling stub.

    NOTE(review): unimplemented placeholder — returns None as written.
    """
    # Implement classifier-guided sampling here...
    pass
7. 可视化工具
import matplotlib.pyplot as plt
def plot_images(images, n_rows=1):
    """Display a batch of generated images on an n_rows grid.

    Assumes CHW tensors normalized to [-1, 1]; they are shifted back
    to [0, 1] for display.
    """
    n_cols = len(images) // n_rows
    plt.figure(figsize=(10, 10))
    for idx, img in enumerate(images):
        plt.subplot(n_rows, n_cols, idx + 1)
        # CHW -> HWC, then undo the [-1, 1] normalization.
        plt.imshow(img.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
        plt.axis('off')
    plt.show()
8. 性能优化技巧
- 混合精度训练:
# Mixed-precision (AMP) training — fragment; assumes `model`, `noisy_images`,
# `t`, `noise`, `optimizer` and `F` are defined in the surrounding training loop.
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    predicted_noise = model(noisy_images, t)
    loss = F.mse_loss(predicted_noise, noise)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
scaler.update()                # adapt the scale factor for the next iteration
- 分布式训练:
# NOTE(review): nn.DataParallel is single-process multi-GPU; DistributedDataParallel
# is the generally recommended alternative — confirm the target setup.
model = nn.DataParallel(model)
- 梯度检查点:
from torch.utils.checkpoint import checkpoint
# Gradient checkpointing trades recomputation for activation memory.
# Fragment: assumes `self.mid_block`, `h` and `t` exist in the enclosing U-Net forward().
h = checkpoint(self.mid_block, h, t)
9. 扩展阅读与资源
提示:实际应用中建议使用成熟的扩散模型库(如HuggingFace Diffusers)作为基础,再根据需求进行定制开发。
