CANN/ops-nn:Keras动量优化算子

📅 2026/7/4 7:20:09
CANN/ops-nn:Keras动量优化算子
ApplyKerasMomentum【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√Atlas 200I/500 A2 推理产品×Atlas 推理系列产品√Atlas 训练系列产品√功能说明算子功能执行Keras Momentum优化器的单步参数更新。根据动量系数momentum、学习率lr和梯度grad更新动量累积量accum并按标准模式或Nesterov模式原地更新权重参数varinplace 语义。对标TensorFlow中tf.raw_ops.ResourceApplyKerasMomentum接口的计算语义。计算公式$$ \begin{aligned} accum_{new} momentum \cdot accum - lr \cdot grad \end{aligned} $$标准模式use_nesterov 0$$ var_{new} var accum_{new} $$Nesterov模式use_nesterov 1$$ var_{new} var momentum \cdot accum_{new} - lr \cdot grad $$其中momentum为动量系数lr为学习率grad为当前梯度accum为动量累积量var为待更新的权重参数。参数说明参数名输入/输出/属性描述数据类型数据格式var输入 / 输出 (inplace)待更新的权重参数对应公式中的var。Kernel内inplace更新GE IR输出var_out与输入var共享Device内存。FLOAT16、FLOAT、BFLOAT16NDaccum输入 (inplace 更新)动量累积量对应公式中的accum。shape/dtype必须与var一致Kernel内显式写回输入GM地址。FLOAT16、FLOAT、BFLOAT16NDgrad输入当前梯度Tensor对应公式中的grad。shape/dtype必须与var一致。FLOAT16、FLOAT、BFLOAT16NDlr输入学习率对应公式中的lr。shape{1}的1元素scalar Tensordtype为FLOAT。FLOATNDmomentum输入动量系数对应公式中的momentum。shape{1}的1元素scalar Tensordtype为FLOAT。FLOATNDuse_nesterov输入是否使用Nesterov动量对应公式中的use_nesterov。shape{1}的1元素scalar Tensordtype为FLOAT0.0标准模式1.0Nesterov模式。FLOATNDvar输出更新后的var Tensor与输入var共享Device内存inplace。FLOAT16、FLOAT、BFLOAT16ND注accum 为 in-place 更新与 TensorFlow ResourceApplyKerasMomentum 一致不作为显式输出返回计算结果直接写回输入地址。约束说明var、accum、grad必须具有相同的数据类型和形状。lr、momentum、use_nesterov为标量Tensorshape{1}数据类型为FLOAT。调用说明调用方式调用样例说明图模式test_geir_apply_keras_momentum通过 算子IR 构图方式调用 ApplyKerasMomentum 算子。【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考