CANN社区ScatterND算子设计文档

📅 2026/7/4 7:18:06
CANN社区ScatterND算子设计文档
【社区任务】ScatterND 算子设计文档【免费下载链接】cann-ops-competitions本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。项目地址: https://gitcode.com/cann/cann-ops-competitions一、需求背景1.1 需求来源通过社区任务完成 ScatterND 算子在 ops-nn 开源仓中的实现与适配。1.2 背景介绍1.2.1 ScatterND 算子实现优化ScatterND 算子用于将 updates 张量中的值根据 indices 索引分散更新到 input 张量的指定位置。TBE 算子源码路径本次任务参考实现该文件为 ScatterNd TBE legacy 算子适配脚本基于 Tik DSL 实现。原始 CANN toolkit 内置路径通常位于/home/developer/Ascend/cann-9.0.0-beta.2/opp/built-in/op_impl/ai_core/tbe/impl/ops_legacy/dynamic/scatter_nd.pykernel 实现/home/developer/Ascend/ascend-toolkit/latest/opp/built-in/op_impl/ai_core/tbe/impl/dynamic/算子信息库路径/home/developer/Ascend/ascend-toolkit/latest/opp/built-in/op_impl/ai_core/tbe/config/ascend910b算子原型路径/home/developer/Ascend/ascend-toolkit/latest/opp/built-in/op_proto/inc/参考文件正确性验证aclnn 接口为aclnnScatterNd其输入为(input, indices, updates)输出为output。TBE 源码scatter_nd.py中算子注册名为ScatterNdregister_operator(ScatterNd)与 aclnn 接口名称及原型注册REG_OP(ScatterNd)一致入口函数scatter_nd(indices, x, shape, y, kernel_nameScatterNd, impl_modeNone)的参数语义为indices索引张量对应 aclnn 的indicesx更新张量updatesshape输出张量output的 shape由 input shape 推导y输出张量output因此该函数包含 3 个输入参数indices/x/shape和 1 个输出参数y与原型注册参数一致入口函数根据平台能力选择 Tik 或 DSL 实现路径若平台不支持tbe.dsl.vexpfloat32走 Tik 路径scatter_nd_tik否则走 DSL 路径scatter_nd_dslTik 路径的ScatterNd类通过 Tik 直接 CCE 编程包含 16 种tiling_mode分支分为 atomicmode 1-5, 16和非 atomicmode 6-15两大类DSL 路径调用tbe.scatter_nd(var_tensor, indices_tensor, updates_tensor, reduction)完成计算语义定义参数语义与 aclnn 的input/indices/updates对应。因此上述 TBE 源码文件为本算子的正确参考实现。1.2.2 ScatterND算子现状分析1.2.2.1 支持的数据类型和数据格式根据算子信息库ScatterND 的 TBE 版本在 Atlas A2 平台支持的数据类型为 FLOAT16、FLOAT、INT32。算子信息库声明支持动态编译、动态 format、动态 rank、动态 shape输入输出 shape 能力为 all。 本算子最终支持 ND storage formatdata/updates/output dtype 支持 FLOAT16、FLOAT、INT32indices dtype 支持 INT32、INT64。项目本次任务范围当前实现输入input, indices, updatesinput, indices, updates输出outputoutputdtypeFLOAT16, FLOAT, INT32FLOAT16, FLOAT, INT32indices dtypeINT32, INT64INT32, INT64storage formatNDNDdynamicdynamic shape/rank/format运行时推导 shape1.2.2.2 算子实现描述TBE 适配脚本的主要逻辑如下scatter_nd入口函数根据平台能力选择 Tik 或 DSL 实现路径。scatter_nd_tikTik 路径创建 ScatterNd 类实例通过 Tik 直接构建多核并行计算程序。scatter_nd_dslDSL 路径通过 TBE DSL 的tbe.scatter_nd()接口完成计算由框架自动调度。Tik 实现核心流程ScatterNd.__init__()完成参数校验check_input_params、缓冲区大小计算、Tik Tensor/Scalar 声明。scatter_nd_compute_tiling()从 GM 读取 35 参数 tiling 到 UB解析 tiling_args根据tiling_mode选择分核策略。分为两大类执行路径atomic 模式tiling_mode 1-5, 16通过set_atomic_add直接将 updates 原子累加到 output GM。非 atomic 模式tiling_mode 6-15从 GM 读取 var → UB在 UB 中vec_addvar updates后写回 GM。每个模式根据 UB 容量、32B 对齐、数据量大小选择不同的搬运和计算策略。BuildCCE编译生成最终核函数。DSL 实现核心流程参数校验dtype 检查。classify对输入进行分类shape_util.variable_shape推导动态 shape。创建占位 Tensorvar_tensor、indices_tensor、updates_tensor、shape_tensor。调用tbe.scatter_nd(var_tensor, indices_tensor, updates_tensor, reduction)完成计算语义定义。tbe.auto_schedule自动调度tbe.build编译构建。核心计算语义如下# indices.shape (M, N), updates.shape (M, *input.shape[N:]) for idx in range(M): # 根据 indices 计算多维偏移量 outputIdx compute_offset(indices[idx], output_shape) output[outputIdx] updates[idx]1.2.2.3 算子实现流程图以下为 TBE 参考实现的完整流程图包含 DSL 与 Tik 两条路径以及 Tik 路径中 16 种tiling_mode的详细分支。![TBE 参考实现流程图](https://raw.gitcode.com/cann/cann-ops-competitions/raw/48be878aeec13265640688bf02d643fb416f68b5/04_tasks/01_community-task-2026/tasklist/03-18-ScatterNd/yuanfan/docs/![ScatterNdTensor.png](https:/raw.gitcode.com/user-images/assets/10130836/9f04043d-366d-43b3-9f93-e331db929be3/ScatterNdTensor.png ScatterNdTensor.png?utm_sourcegitcode_repo_files))二、需求分析2.1 外部组件依赖外部依赖使用位置作用TBE 框架TBE 适配脚本参数校验、调度生成、编译构建Tik 框架Tik 路径直接 CCE 编程多核并行、UB 管理、vec 指令TBE DSLDSL 路径tbe.scatter_nd计算语义auto_scheduleACL Runtime核函数设备内存管理和 kernel 启动2.2 内部适配模块模块文件作用TBE Tik 适配模块scatter_nd.py参数校验、Tik 多核并行构建、tiling 分发、编译调用TBE DSL 适配模块scatter_nd.pyDSL compute auto_schedule build 路径Tik 计算类ScatterNd同文件内Tik 程序主体包含 16 种 tiling_mode 分支2.3 需求模块设计2.3.1 算子原型名称类别dtypeshape/取值说明input输入ND: FLOAT16/FLOAT/INT32任意 shape输入张量indices输入ND: INT32, INT64(M, N)索引张量updates输入与 input 相同(M, *input.shape[N:])更新张量output输出与 input 相同与 input 相同输出张量2.3.2 算子相关约束最终支持范围与 TBE/原型范围保持一致。约束如下input和updates的 dtype 必须一致。indices的 dtype 必须为 INT32 或 INT64。indices和updates的 shape 必须完全一致。input的前 N 维必须与indices的 shape 匹配N len(indices.shape)。三、需求详细设计3.1 使能方式当前实现适配 TBE 调用框架。调用流程为scatter_nd(input, indices, updates)接口完成参数校验、输出 shape 推导、调度生成和核函数调用。3.2 需求总体设计3.2.1 host 侧设计host 侧在执行阶段解析输入 shape。TBE 参考实现中的关键参数与分核策略由tiling在运行前计算核函数通过tiling_gm传入 35 个 int64 参数包括tiling_mode执行模式选择1-16core_num实际使用的核数indice_step每个核负责的索引步长update_data_num每个索引对应的更新数据量indices_loop_num / indices_last_numindices 循环次数和尾数var_each_core_data / var_last_core_data每个核负责的 output 数据量var_each_core_set_zero_loop_num / var_last_core_set_zero_loop_numoutput 初始化清零参数。3.2.2 kernel 侧设计3.2.2.1 kernel 侧实现描述本次任务 kernel 侧设计与 TBE 参考实现保持一致采用 Tik 手写实现入口阶段scatter_nd根据平台能力选择scatter_nd_tik或scatter_nd_dsl。Tik 初始化阶段ScatterNd.__init__()校验参数计算 UB 容量、数据块大小、atomic 支持等并声明 Tik Tensor/Scalar。tiling 解析阶段scatter_nd_compute_tiling()从 GM 读取 tiling 参数解析后按tiling_mode分发。多核并行阶段每个 AI Core 执行init_ub_tensor()后调用traversing_indices(mode)处理 indices。索引计算阶段get_var_read_index(indices_ub_index)根据 indices 的每一行和var_offset_index_tiling计算 output 偏移。更新阶段atomic 模式move_indices选择 atomic 分支通过set_atomic_add将 updates 直接写回out_gm非 atomic 模式move_indices选择非 atomic 分支根据具体tiling_mode将out_gm中对应区域读取到var_ub与updates做vec_add后写回out_gm。traversing_var_single_core / traversing_var_mul_core等函数在源码中作为清零/初始化辅助函数存在用于特定场景。编译阶段scatter_nd_operator()调用BuildCCE生成核函数。核心计算公式outputIdx sum(indices[idx][k] * var_offset_index_tiling[k]) # k 0..indices_last_dim-1 output[outputIdx * update_data_num innerIdx] updates[idx * update_data_num innerIdx]3.2.2.2 AscendC 实现流程图3.2.2.3 实现流程图与 TBE 流程图存在的差异点和原因差异点TBE 参考实现E:\download\scantter_nd\tbe\scatter_nd.py本次任务设计与 TBE 一致原因实现路径同时存在 Tik 路径和 DSL 路径运行时根据平台能力选择与 TBE 一致保持最大兼容性和性能调度方式Tik 路径为显式手动分核调度tiling_mode 多核 for_range与 TBE 一致手动控制并行粒度和 UB 使用动态 shape由 TBE 框架 / tiling 参数在运行时推导与 TBE 一致实现形式由框架决定计算表达Tik 路径显式使用set_atomic_add或vec_add完成更新与 TBE 一致数学语义一致与硬件能力匹配初始化阶段非 atomic 模式下读取 out_gm 对应区域到 var_ub与 updates 做vec_add后写回traversing_var_single_core / traversing_var_mul_core等清零辅助函数在源码中定义用于特定场景与 TBE 一致与 TBE 参考源码保持一致3.3 支持硬件Atlas A2 训练系列产品。3.4 算子约束限制支持 dtypeFLOAT16、FLOAT、INT32。storage format 支持 ND且输入输出必须一致。indices dtype 支持 INT32、INT64。indices 和 updates 的 shape 必须完全一致。四、特性交叉分析特性当前实现动态 shapehost 侧根据运行时 shape 推导输出动态 format当前支持 NDdtypeND 支持 FLOAT16、FLOAT、INT32indices dtype支持 INT32、INT64五、可维可测分析5.1 精度标准/性能标准标准描述精度标准输出满足 AscendOpTest 默认精度阈值性能标准所有核参与计算场景下性能与 TBE 持平5.2 兼容性分析当前实现的参数顺序与 TBE 接口语义一致输入输出、属性列表、dtype/format 注册和 shape 推导均以 TBE 实现为基准。【免费下载链接】cann-ops-competitions本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。项目地址: https://gitcode.com/cann/cann-ops-competitions创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考