Conv3DBackpropInput使用说明【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言原生支持C和C标准规范主要由类库和语言扩展层构成提供多层级API满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkitAscend C提供一组Conv3DBackpropInput高阶API便于用户快速实现卷积的反向运算求解反向传播的误差。转置卷积Conv3DTranspose与Conv3DBackpropInput具有相同的数学过程因此用户也可以使用Conv3DBackpropInput高阶API实现转置卷积算子。卷积的正反向传播如图1卷积层的前后向传播示意图反向传播误差计算如图2 反向传播误差计算示意图。Conv3DBackpropInput的计算公式为∂L/∂Y为卷积正向损失函数对输出Y的梯度GradOutput作为求反向传播误差∂L/∂X的输入。W为卷积正向Weight权重即矩阵核Kernel也是滤波器Filter作为求反向传播误差∂L/∂X的输入WT表示W的转置。∂L/∂X为特征矩阵的反向传播误差GradInput。图1卷积层的前后向传播示意图图2反向传播误差计算示意图Kernel侧实现Conv3DBackpropInput求解反向传播误差运算的步骤概括为创建Conv3DBackpropInput对象。初始化操作。设置卷积的输出反向GradOutput、卷积的输入Weight。完成卷积反向操作。结束卷积反向操作。[!NOTE]说明 下文中提及的M轴方向即为GradOutput矩阵纵向K轴方向即为GradOutput矩阵横向或Weight矩阵纵向N轴方向即为Weight矩阵横向。使用Conv3DBackpropInput高阶API求解反向传播误差运算的具体步骤如下创建Conv3DBackpropInput对象。#include lib/conv_backprop/conv3d_bp_input_api.h using weightDxType ConvBackpropApi::ConvTypeConvCommonApi::TPosition::GM, ConvCommonApi::ConvFormat::FRACTAL_Z_3D, weightType; using inputSizeDxType ConvBackpropApi::ConvTypeConvCommonApi::TPosition::GM, ConvCommonApi::ConvFormat::ND, int32_t; using gradOutputDxType ConvBackpropApi::ConvTypeConvCommonApi::TPosition::GM, ConvCommonApi::ConvFormat::NDC1HWC0, gradOutputType; using gradInputDxType ConvBackpropApi::ConvTypeConvCommonApi::TPosition::GM, ConvCommonApi::ConvFormat::NCDHW, gradInputType; ConvBackpropApi::Conv3DBackpropInputweightDxType, inputSizeDxType, gradOutputDxType, gradInputDxType gradInput_;创建对象时需要传入权重矩阵Weight、卷积正向特征矩阵Input的shape信息InputSize、GradOutput和GradInput的参数类型信息类型信息通过ConvType来定义包括内存逻辑位置、数据格式、数据类型。template TPosition POSITION, ConvFormat FORMAT, typename T struct ConvType { constexpr static TPosition pos POSITION; // Convolution输入或输出的逻辑位置 constexpr static ConvFormat format FORMAT; // Convolution输入或输出的数据格式 using Type T; // Convolution输入或输出的数据类型 };下面简要介绍在创建对象时使用到的相关数据结构开发者可选择性地了解这些内容。用于创建Conv3DBackpropInput对象的数据结构定义如下using Conv3DBackpropInput Conv3DBpInputIntf Conv3DBpInputCfgWEIGHT_TYPE, INPUT_TYPE, GRAD_OUTPUT_TYPE, GRAD_INPUT_TYPE, CONV3D_CFG_DEFAULT, Conv3DBpInputImpl;其中Conv3DBpInputIntf、Conv3DBpInputCfg数据结构定义如下template class Config_, template typename, class class Impl struct Conv3DBpInputIntf {}template class WEIGHT_TYPE, class INPUT_TYPE, class GRAD_OUTPUT_TYPE, class GRAD_INPUT_TYPE, const Conv3dConfig CONV3D_CONFIG CONV3D_CFG_DEFAULT struct Conv3DBpInputCfg : public ConvBpContextWEIGHT_TYPE, INPUT_TYPE, GRAD_OUTPUT_TYPE, GRAD_INPUT_TYPE {}表1ConvType说明| 参数 | 说明 | | --- | --- | | POSITION | 内存逻辑位置。Weight矩阵可设置为TPosition::GM。GradOutput矩阵可设置为TPosition::GM。InputSize可设置为TPosition::GM。GradInput矩阵可设置为TPosition::GM。 | | ConvFormat | 数据格式。Weight矩阵可设置为ConvFormat::FRACTAL_Z_3D。GradOutput矩阵可设置为ConvFormat::NDC1HWC0。InputSize矩阵可设置为ConvFormat::ND。GradInput矩阵可设置为ConvFormat::NDC1HWC0。 | | TYPE | 数据类型。Weight矩阵可设置为half、bfloat16_t。GradOutput矩阵可设置为half、bfloat16_t。InputSize矩阵可设置为int32_t。GradInput矩阵可设置为half、bfloat16_t。注意GradOutput矩阵和Weight矩阵数据类型需要一致具体数据类型组合关系请参考表2。 |表2Conv3DBackpropInput输入输出数据类型的组合说明WeightGradOutputInputSizeGradInput支持平台halfhalfint32_thalfAtlas A3 训练系列产品/Atlas A3 推理系列产品Atlas A2 训练系列产品/Atlas A2 推理系列产品bfloat16_tbfloat16_tint32_tbfloat16_tAtlas A3 训练系列产品/Atlas A3 推理系列产品Atlas A2 训练系列产品/Atlas A2 推理系列产品初始化操作。// 注册后进行初始化 ConvBackpropApi::Conv3DBackpropInputweightDxType, inputSizeDxType, gradOutputDxType, gradInputDxType gradInput_; gradInput_.Init((tilingData-conv3DDxTiling));设置3D卷积的输出反向GradOutput、3D卷积的输入Weight。gradInput_.SetSingleShape(singleShapeM_, singleShapeK_, singleShapeN_); // 设置单核计算的形状 gradInput_.SetStartPosition(dinStartIdx_, curHoStartIdx_); // 设置单核上gradOutput载入的起始位置 gradInput_.SetGradOutput(gradOutputGm_[offsetA_]); gradInput_.SetWeight(weightGm_[offsetB_]);完成卷积反向操作。调用Iterate完成单次迭代计算叠加while循环完成单核全量数据的计算。Iterate方式可以自行控制迭代次数完成所需数据量的计算。while (gradInput_.Iterate()) { gradInput_.GetTensorC(gradInputGm_[offsetC_]); }结束卷积反向操作。gradInput_.End();需要包含的头文件#include lib/conv_backprop/conv3d_bp_input_api.h【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言原生支持C和C标准规范主要由类库和语言扩展层构成提供多层级API满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考