torch.utils.checkpoint.checkpoint
是 PyTorch 提供的一种内存优化工具,用于在计算图的反向传播过程中节省显存。它通过重新计算某些前向传播的部分,减少了保存中间激活值所需的显存,特别适用于深度模型,如 Transformer 等层数较多的网络。
主要原理
在标准的反向传播中,前向传播过程中每一层的中间激活值(activation)会被保留,供后续反向传播使用。但在使用 checkpoint
时,某些层的激活值不会在前向传播时保存,而是在反向传播时通过重新计算这些层的前向结果来获得。
这样可以节省大量内存,但代价是增加了计算量,因为反向传播时需要重新计算部分前向传播。
使用方法
示例代码1
import torch
from torch.utils.checkpoint import checkpointdef forward_pass(x):# 假设是模型的一部分return x * x + 2 * x + 1# 在调用时使用 checkpoint 包裹住前向传播的函数
input_tensor = torch.randn(3, requires_grad=True)
output = checkpoint(forward_pass, input_tensor)
output.backward()print(input_tensor.grad) # 查看梯度
示例代码2
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint