当前位置: 首页> 教育> 幼教 > torch.compile模型编译加速

torch.compile模型编译加速

时间:2025/7/11 19:57:10来源:https://blog.csdn.net/weixin_40777649/article/details/140441472 浏览次数:0次

一、定义

  1. 定义
  2. 接口介绍
  3. 案例

二、实现

  1. 定义

    1. torch.compile 是加速 PyTorch 代码的最新方法! torch.compile 通过 JIT 将 PyTorch 代码编译成优化的内核,使 PyTorch 代码运行得更快,大部分过程仅需修改一行代码。
    2. torch.compile 的一个重要组件就是 TorchDynamo。TorchDynamo 负责将任意 Python 代码即时编译成 FX Graph(计算图),然后可以进一步优化。TorchDynamo 通过在运行时分析 Python 字节码并检测对 PyTorch 操作的调用来提取 FX Graph。
    3. torch.compile 的另一个重要组件 TorchInductor 会将 FX Graph 进一步编译成优化的内核。TorchDynamo 允许使用不同的后端,所以为了检查 TorchDynamo 输出的 FX Graph,可以创建一个自定义后端来输出 FX Graph 并简单地返回 Graph 未优化的前向内容。
    4. 允许自定义函数
      开始编译的时候需要耗费大量的时间,即第一次请求,时间较长。
      5. 详情见: https://pytorch.org/docs/stable/torch.compiler.html
      https://pytorch.org/get-started/pytorch-2.0/
  2. 接口介绍

modoel_compile = torch.compile(model, mode="reduce-overhead")
(默认)default: 适合加速大模型,编译速度快且无需额外存储空间
reduce-overhead:适合加速小模型,需要额外存储空间
max-autotune:编译速度非常耗时,但提供最快的加速
  1. 案例
import torch
def foo(x, y):a = torch.sin(x)b = torch.cos(x)return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
#方式二
@torch.compile
def opt_foo2(x, y):a = torch.sin(x)b = torch.cos(x)return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
方式三
class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.lin = torch.nn.Linear(100, 10)def forward(self, x):return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))

训练

import torch
import torchvision.models as modelsmodel = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()

保存:

torch.save(optimized_model.state_dict(), "foo.pt")
# both these lines of code do the same thing
torch.save(model.state_dict(), "foo.pt")

推理:

# API Not Final
exported_model = torch._dynamo.export(model, input)
torch.save(exported_model, "foo.pt")
关键字:torch.compile模型编译加速

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: