CANN算子库torch_extension开发规范

📅 2026/6/17 7:18:19
CANN算子库torch_extension开发规范
torch_extension开发规范【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer本文档约定cann_ops_transformertorch_extension新增/修改算子api时的目录组织、命名、各层实现以及文档编写规范。开发者新增算子前请先通读本规范并参考已有算子如flash_attn作为模板。cann_ops_transformer通过PyTorch的JITtorch.utils.cpp_extension.load在首次调用时即时编译C kernel wrapper把PyTorch的函数接口桥接到CANN的aclnn接口同时通过GE Converter支持torchair图模式。一个完整的算子api通常由「Python前端、C后端、torchair图模式Converter、文档」四部分组成。1. 算子 api 目录组织规范总体开发新增算子api涉及新增与修改以下文件以算子${op_api}为例├── torch_extension ├── cann_ops_transformer ├── __init__.py # 包根入口对外导出算子 api 接口新增 import ├── op_builder │ ├── builder.py # OpBuilder 基类统一管理 JIT 编译、schema/meta 注册一般无需修改 ├── common │ ├── inc │ │ ├── aclnn_common.h # ACLNN_CMD 宏、类型转换等公共能力一般无需修改 │ │ ├── hccl_common.h # 通信类算子公共能力 ├── ops │ ├── csrc │ │ ├── ${op_api}.cpp # 算子 api 的 C 实现调用 aclnn 接口 │ ├── graph_convert │ │ ├── graph_convert_${op_api}.py # 算子 api 的 torchair 图模式GE Converter实现 │ ├── __init__.py # 对外导出新增算子 api新增 import │ └── ${op_api}.py # 算子 api 的 Python 前端实现OpBuilder、schema、meta、对外函数 ├── docs ├── torch_extension_guidelines.md # 本开发规范 ├── zh ├── ${op_api}.md # 算子 api 的中文文档新增一个算子api的标准动作清单以flash_attn为例在ops/csrc/flash_attn.cpp中实现C kernel wrapper调用ACLNN_CMD拉起aclnn接口在ops/flash_attn.py中编写OpBuilder子类定义sources/schema/register_meta注册dispatcher实现并提供对外的Python函数在ops/graph_convert/graph_convert_flash_attn.py中编写图模式Converter若需支持图模式在ops/__init__.py中导出新增的对外接口在包根cann_ops_transformer/__init__.py中import导出对外接口使用户可直接从包根访问在docs/zh/flash_attn.md中补充算子文档。新增文件请放在cann_ops_transformer包下import路径统一以cann_ops_transformer为根。2. 命名规范2.1 API 命名一个算子从schema注册到对外导出涉及多个层级的命名需保持一致且各司其职。对外api接口及算子名一律不带npu_前缀直接采用算子语义的小写蛇形名如flash_attn层级命名约定示例Library名DEF域固定为cann_ops_transformerAS_LIBRARY Library(cann_ops_transformer, DEF)schema算子名 / aten注册名算子语义的小写蛇形名不带npu_前缀flash_attnC wrapper函数名与schema算子名一致置于namespace op_api内op_api::flash_attnPYBIND11_MODULE导出名与schema算子名一致m.def(flash_attn, flash_attn, flash_attn);Meta实现函数名schema算子名 _meta后缀flash_attn_metaPrivateUse1 dispatcher函数名下划线前缀 schema算子名_flash_attnOpBuilder子类名算子名的大驼峰 OpBuilder后缀内部专用可加_前缀FlashAttnOpBuilder、_FlashAttnOpBuilder对外Python接口名用户直接调用的函数名体现使用语义不带npu_前缀flash_attn图模式GE op函数名与GE算子op_type一致的大驼峰FlashAttentionScore图模式Converter函数名convert_ schema算子名convert_flash_attn命名要点不带npu_前缀对外算子名与api接口统一使用算子语义名小写蛇形不加npu_等后端前缀schema名、C函数名、pybind导出名三者必须与该名字完全一致否则JIT编译产物无法被正确调用。接口名体现语义对外函数名应贴近业务语义。无论是纯透传aclnn接口的算子如flash_attn还是封装了结构体构造、参数整理等额外逻辑的接口均采用语义化命名。aclnn接口名独立底层aclnn接口沿用CANN命名大驼峰如aclnnFlashAttentionScore与对外算子名解耦C wrapper内通过ACLNN_CMD(aclnnFlashAttentionScore, ...)调用。版本后缀同一算子的不同迭代版本以_v2、_v3等后缀区分schema名、文件名、Converter名需同步带上版本后缀如flash_attn_v2、graph_convert_flash_attn_v2.py。辅助/工具接口与主算子配套的工具函数采用动宾语义命名如get_flash_attn_workspace_size。2.2 文件命名统一使用小写蛇形命名snake_case单词以_连接禁止使用大写、驼峰或连字符。同一算子的各层文件主名保持一致仅靠目录和前缀区分职责Python前端ops/${op_api}.py如flash_attn.pyC后端ops/csrc/${op_api}.cpp主名与Python前端一致如flash_attn.cpp图模式ops/graph_convert/graph_convert_${op_api}.py统一加graph_convert_前缀如graph_convert_flash_attn.py文档docs/zh/${op_api}.md如flash_attn.md。文件主名应与该文件主要导出的算子语义对应带版本的算子文件名需带版本后缀如flash_attn_v2.cpp。公共头文件放在common/inc/下按能力域命名如aclnn_common.h、hccl_common.h。2.3 标识符命名Python 标识符函数/变量/参数小写蛇形snake_case如head_num、scale_value、input_layout。类名大驼峰PascalCase如OpBuilder、FlashAttnOpBuilder。模块级常量全大写蛇形UPPER_SNAKE_CASE如AS_LIBRARY、ASCEND_HOME_PATH、TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP。模块内部私有符号以单下划线_前缀标识如_flash_attn_op_builder、_op_module、_flash_attn、_TORCHAIR_AVAILABLE。类型注解对外接口与关键内部函数应带类型注解from typing import Optional, Tuple, List可选参数统一用Optional[...]例如def flash_attn( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, atten_mask: Optional[torch.Tensor] None, scale_value: float 1.0, head_num: int 1, input_layout: str BSH, ) - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:参数命名一致性同一算子在schema、meta、dispatcher、对外函数、Converter中的同义参数应使用相同的名字如head_num、scale_value、input_layout避免在不同层出现不一致写法。C 标识符函数/局部变量/参数小写蛇形如head_num、scale_value、input_layout_ptr。命名空间算子实现统一置于namespace op_api内。常量const/constexpr常量使用全大写蛇形或大驼峰如const int DIM_THREE 3;、kATenScalarTypeToAclDataTypeTable。类型别名/结构体大驼峰如TensorWrapper、TensorListWrapper。入参类型约定必选Tensor用const at::Tensor 可选Tensor用const c10::optionalat::Tensor Tensor列表用const std::vectorat::Tensor 可选列表用const c10::optionalstd::vectorat::Tensor 整型属性用int64_t可选整型属性用c10::optionalint64_t浮点属性用double字符串属性用std::string。Schema 标识符算子签名参数名采用小写蛇形与Python/C层一致。用*分隔位置参数与关键字参数*之前为必选的位置参数之后为可选的关键字参数带默认值。可选参数以?标注并给出默认值如Tensor? atten_maskNone、int? head_num1列表用Tensor[]可选列表用Tensor[]?。多输出用元组表示如- (Tensor, Tensor, Tensor, Tensor)。以flash_attn为例的schemaflash_attn(Tensor query, Tensor key, Tensor value, *, Tensor? atten_maskNone, float scale_value1.0, int head_num1, str input_layoutBSH) - (Tensor, Tensor, Tensor, Tensor)3. 各层实现规范3.1 C 后端ops/csrc/${op_api}.cpp负责把PyTorch张量桥接到aclnn C-API规范要点文件头部包含#include torch/extension.h与#include aclnn_common.h实现置于namespace op_api。函数签名与schema严格对应必选/可选参数类型按2.3 C入参类型约定选择。入参校验使用TORCH_CHECK(cond, msg...)校验shape、dtype、维度、取值范围等错误信息要可读且包含实际值例如TORCH_CHECK((head_num 0), The head_num should be greater than 0, current is: , head_num); TORCH_CHECK((query.scalar_type() key.scalar_type()), query and key should have the same dtype.);设置DeviceGuard关键在申请输出张量之前必须先根据输入张量设置c10::OptionalDeviceGuard把当前NPU设备切到输入张量所在设备并用{}作用域把「DeviceGuard 输出申请」包在一起否则非默认卡调用时输出张量会落到错误设备导致device不一致at::Tensor attention_out{nullptr}; { auto local_device c10::Device(query.device()); const c10::OptionalDeviceGuard device_guard(local_device); attention_out at::empty(query.sizes(), query.options()); // ... 其余输出 ... }输出张量手动申请在DeviceGuard生效的作用域内按meta推导的shape/dtype用at::empty(...)申请输出标准PyTorch实践dtype通过query.options().dtype(...)指定。拉起kernel使用ACLNN_CMD(aclnn接口名, 入参..., 出参...)宏调用aclnn接口如ACLNN_CMD(aclnnFlashAttentionScore, ...)入参顺序需与aclnn接口定义一致该宏自动完成类型转换、workspace申请、stream下发与资源释放。导出绑定通过PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)将C函数绑定为与schema同名的Python接口m.def(flash_attn, flash_attn, flash_attn);。魔数如维度数3、默认dtype枚举应以具名常量表达避免裸写字面量。C wrapper的典型骨架std::tupleat::Tensor, at::Tensor, at::Tensor, at::Tensor flash_attn( const at::Tensor query, const at::Tensor key, const at::Tensor value, const c10::optionalat::Tensor atten_mask, double scale_value, int64_t head_num, std::string input_layout) { // 3. 入参校验 TORCH_CHECK((head_num 0), The head_num should be greater than 0, current is: , head_num); at::Tensor attention_out{nullptr}; { // 4. DeviceGuard必须在申请输出之前作用域包住输出申请 auto local_device c10::Device(query.device()); const c10::OptionalDeviceGuard device_guard(local_device); // 5. 申请输出张量 attention_out at::empty(query.sizes(), query.options()); // ... 其余输出 ... } // 6. 拉起 aclnn kernel ACLNN_CMD(aclnnFlashAttentionScore, query, key, value, atten_mask, scale_value, head_num, input_layout.data(), /* outputs */ attention_out); return std::make_tuple(/* ... */); }3.2 Python 前端ops/${op_api}.py负责JIT编译管理、schema/meta注册与对外接口封装OpBuilder子类继承OpBuilder在__init__中以super().__init__(schema算子名)传入算子名并实现三个抽象方法sources()返回相对cann_ops_transformer包根的C源文件路径列表如[ops/csrc/flash_attn.cpp]schema()返回算子schema字符串见2.3 Schema标识符register_meta()用impl(AS_LIBRARY, self.name, Meta)注册Meta实现仅做shape/dtype推导不触碰真实NPU计算FakeTensor/图模式必需。Meta中同样可用torch._check(...)做约束校验。实例化与编译模块加载时实例化builder并load()触发编译_flash_attn_op_builder _FlashAttnOpBuilder() _op_module _flash_attn_op_builder.load()PrivateUse1 dispatcher用impl(AS_LIBRARY, builder.name, PrivateUse1)注册NPU后端实现函数体透传到编译产物_op_module.算子名(...)如_op_module.flash_attn(...)。PrivateUse1是PyTorch为自定义NPU后端预留的dispatch key。对外接口提供面向用户的函数flash_attn(...)负责参数整理、默认值处理等最终调用dispatcher实现。对外api必须书写注释docstring每个对外导出的接口都要有docstring至少覆盖「功能说明、各参数含义/shape/dtype/取值范围、返回值说明」必要时给出简短调用示例。docstring内容应与docs/zh/${op_api}.md保持一致便于IDE提示与help()查看。例如def flash_attn( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, atten_mask: Optional[torch.Tensor] None, scale_value: float 1.0, head_num: int 1, input_layout: str BSH, ) - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: FlashAttention 前向计算封装 aclnnFlashAttentionScore。 Args: query (Tensor): 查询张量shape 由 input_layout 决定如 BSHdtype 支持 float16/bfloat16。 key (Tensor): 键张量dtype 与 query 一致。 value (Tensor): 值张量dtype 与 query 一致。 atten_mask (Tensor, optional): 注意力掩码默认 None 表示不使用。 scale_value (float): 缩放系数默认 1.0。 head_num (int): 单卡 head 数即 query 的 N 轴长度默认 1。 input_layout (str): 输入数据排布支持 BSH/BNSD 等默认 BSH。 Returns: Tuple[Tensor, Tensor, Tensor, Tensor]: softmax_max、softmax_sum、softmax_out、attention_out。 Meta实现、dispatcher、对外函数三者的参数顺序与默认值必须与schema一致。3.3 图模式 Converterops/graph_convert/graph_convert_${op_api}.py负责在torchair图模式GE下把aten算子转换为GE节点可选依赖保护torchair相关import统一包在try/except ImportError中用_TORCHAIR_AVAILABLE标志位控制避免在无torchair环境下导入失败。GE op函数定义与op_type同名的大驼峰函数如FlashAttentionScore通过docstring写明REG_OP的IR定义INPUT/DYNAMIC_INPUT/OPTIONAL_INPUT/OUTPUT/ATTR等并组织inputs/attrs/outputs后调用ge_op(...)IR通过IrDef(...)链式声明。Converter注册用register_fx_node_ge_converter(torch.ops.cann_ops_transformer.flash_attn.default)装饰convert_flash_attn函数其参数顺序与schema完全一致函数体调用上面的GE op函数。在ops/__init__.py中导出Converter如convert_flash_attn确保注册逻辑被执行。3.4 对外导出ops/__init__.py与包根__init__.py对外导出分两级两处都需新增importops/__init__.py子包层每个新增算子的对外接口与Converter都需在此显式import导出导入即触发schema/meta/converter注册from .flash_attn import flash_attn from .graph_convert.graph_convert_flash_attn import convert_flash_attn同一算子若导出多个符号使用括号分组的多行import。cann_ops_transformer/__init__.py包根层除已有的from . import ops触发注册外还需把对外接口import到包根命名空间使用户可直接通过cann_ops_transformer.接口名访问而不必写完整的cann_ops_transformer.ops.接口名路径from . import ops from .ops import flash_attn __all__ [flash_attn]建议在包根维护__all__显式列出对外导出的接口名便于管理可见接口集合。完成两级导出后用户既可from cann_ops_transformer import flash_attn也可cann_ops_transformer.ops.flash_attn(...)调用导入主包即完成schema/meta/converter注册。4. 文档规范docs/zh/${op_api}.md每个对外算子api需配套一份中文文档建议章节顺序与已有算子文档如flash_attn.md对齐标题算子名特殊字符如_需转义为\_。产品支持情况表格列出支持的产品形态如Ascend 950PR/Ascend 950DT及是否支持。功能说明API功能概述 计算公式数学表达用LaTeX并说明各符号与参数的对应关系。函数原型代码块给出完整函数签名含默认值与*分隔。参数说明逐个参数说明「必选/可选、语义、shape、dtype、数据格式如$ND$、是否支持非连续Tensor、取值范围/约束」可选参数标注默认值与「暂不支持」说明。输出说明逐个输出说明shape、dtype、格式等。约束说明分类列出参数一致性约束、shape/取值范围约束、量化场景约束等通信类算子还需列出通信域约束。配套接口说明若算子需与其他接口配套使用补充其原型、参数与输出说明。调用示例给出单算子模式必要时含多卡/通信初始化的完整可运行示例图模式若暂不支持需明确标注「图模式调用暂不支持」。5. 编码通用约束许可证头所有新增源文件.py/.cpp/.h必须包含Huawei版权与CANN Open Software License Agreement Version 2.0许可证头年份填当年。Python/脚本用#注释C用//或/* */。接口注释对外api接口必须书写docstring功能、参数、返回值见3.2C wrapper关键逻辑校验、DeviceGuard、aclnn调用也应有简要注释。C层DeviceGuard关键调用aclnn的C wrapper中必须在申请输出张量之前用c10::OptionalDeviceGuard构造自c10::Device(输入张量.device())把设备切到输入张量所在设备详见3.1。参数校验前置Python侧用torch._check(cond, lambda: f...{var}...)C侧用TORCH_CHECK(cond, msg...)错误信息需包含变量实际值便于定位。错误码Python侧可结合torch_npu.utils._error_code的ErrCode/ops_error输出规范错误码如f... {ops_error(ErrCode.VALUE)}.。避免魔数维度数、dtype枚举值等以具名常量表达并在文档/注释中说明枚举含义如23 → float8_e5m2、24 → float8_e4m3fn。公共能力复用类型转换、ACLNN_CMD、通信域处理等优先复用common/inc下的公共头不在各算子中重复实现。一致性自检提交前确认schema、C wrapper、Meta、dispatcher、对外函数、Converter、文档七处的算子名、参数名、参数顺序、默认值保持一致。【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考