PyTorch是由Facebook AI Research (FAIR)开发的开源深度学习框架,以其动态计算图和直观的接口设计而闻名。它结合了灵活的研究原型设计和高效的生产部署能力,已成为学术界和工业界的主流选择。
- 基于tape-based的自动微分系统
- 支持即时图构建和修改(eager execution)
- 调试友好,可直接使用Python原生控制流
- GPU加速的N维数组计算
- 丰富的张量操作库(600+操作符)
- 与NumPy兼容的API设计(
torch.from_numpy())
nn.Module基类实现层/模型封装- 包含200+预定义层和损失函数
- 支持自定义可微分函数(通过
Function类)
| 组件 | 功能描述 |
|---|
| torch | 基础张量库(类似NumPy) |
| torch.nn | 神经网络构建模块 |
| torch.optim | 优化算法实现 |
| torch.utils.data | 数据加载与预处理工具 |
| torchvision | 计算机视觉专用工具 |
- 数据准备
from torch.utils.data import DataLoader
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
- 模型定义
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3)
def forward(self, x):
return self.conv1(x)
- 训练循环
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
for x, y in dataloader:
pred = model(x)
loss = F.cross_entropy(pred, y)
loss.backward()
optimizer.step()
- TorchScript:模型序列化工具
- TorchServe:生产级模型服务
- PyTorch Lightning:高级训练抽象
- TorchText/TorchVision:领域专用库
- ONNX支持:跨框架模型导出
| 版本 | 重要更新 |
|---|
| 1.0 (2018) | 合并Caffe2,支持生产部署 |
| 1.7 (2020) | 添加FFT和稀疏张量支持 |
| 2.0 (2022) | 编译模式(torch.compile) |