From 7fa7b9d23b47eb4c6f55b19c24a5ebe84e757612 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Tue, 5 Jan 2021 10:33:01 +0800 Subject: [PATCH] [MSLITE][DEVELOP] add npu op fullconnection, reduce_mean; add npu testcases --- .../agent/npu/optimizer/npu_fusion_pass.cc | 3 + .../kernel/npu/convolution_base_npu.cc | 7 +- .../runtime/kernel/npu/convolution_base_npu.h | 3 +- .../kernel/npu/convolution_depthwise_npu.cc | 10 +- .../src/runtime/kernel/npu/convolution_npu.cc | 8 +- .../runtime/kernel/npu/deconvolution_npu.cc | 8 +- .../runtime/kernel/npu/fullconnection_npu.cc | 122 ++++++++++++++++++ .../runtime/kernel/npu/fullconnection_npu.h | 47 +++++++ .../lite/src/runtime/kernel/npu/reduce_npu.cc | 72 +++++++++++ .../lite/src/runtime/kernel/npu/reduce_npu.h | 45 +++++++ mindspore/lite/test/models_caffe.cfg | 1 + mindspore/lite/test/models_npu.cfg | 23 ++++ 12 files changed, 342 insertions(+), 7 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.cc create mode 100644 mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.h create mode 100644 mindspore/lite/src/runtime/kernel/npu/reduce_npu.cc create mode 100644 mindspore/lite/src/runtime/kernel/npu/reduce_npu.h diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc index 13f76821c7..07d51f1450 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc @@ -21,6 +21,9 @@ namespace mindspore::lite { bool CheckFusion(kernel::LiteKernel *kernel) { + if (kernel->in_kernels().empty() || kernel->out_kernels().empty()) { + return false; + } auto pre_flag = std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) { return NPUPassUtils::IsNchw2Nhwc(in_kernel) && in_kernel->out_kernels().size() == 1; diff --git a/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.cc b/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.cc index 07e948ad85..282d732b88 100644 --- a/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.cc +++ b/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.cc @@ -34,7 +34,7 @@ ConvolutionBaseNPUKernel::~ConvolutionBaseNPUKernel() { } } -int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vector &inputs) { +int ConvolutionBaseNPUKernel::InitWeightConst(const std::vector &inputs) { weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w"); if (weight_ == nullptr) { MS_LOG(ERROR) << "New weight const failed."; @@ -61,7 +61,10 @@ int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vectorset_attr_value(weight_tensor); free(nchw_data); + return RET_OK; +} +int ConvolutionBaseNPUKernel::InitBiasConst(const std::vector &inputs) { if (inputs.size() >= 3) { bias_ = new (std::nothrow) hiai::op::Const(name_ + "_b"); if (bias_ == nullptr) { @@ -88,7 +91,7 @@ int ConvolutionBaseNPUKernel::SetActivation(const ge::Operator *input, ActType a } else if (act_type == ActType_Relu6) { act_->set_attr_mode(14); } else { - MS_LOG(ERROR) << "Unsupport activation for convolution."; + MS_LOG(ERROR) << "Unsupport activation type for convolution."; return RET_ERROR; } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.h b/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.h index a15163a8ca..31cce4fe99 100644 --- a/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.h +++ b/mindspore/lite/src/runtime/kernel/npu/convolution_base_npu.h @@ -32,7 +32,8 @@ class ConvolutionBaseNPUKernel : public NPUKernel { ~ConvolutionBaseNPUKernel() override; protected: - int InitWeightBiasConst(const std::vector &inputs); + int InitWeightConst(const std::vector &inputs); + int InitBiasConst(const std::vector &inputs); int SetActivation(const ge::Operator *input, ActType act_type); hiai::op::Activation *act_ = nullptr; hiai::op::Const *weight_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/npu/convolution_depthwise_npu.cc b/mindspore/lite/src/runtime/kernel/npu/convolution_depthwise_npu.cc index 6334f9613f..206d9aca25 100644 --- a/mindspore/lite/src/runtime/kernel/npu/convolution_depthwise_npu.cc +++ b/mindspore/lite/src/runtime/kernel/npu/convolution_depthwise_npu.cc @@ -39,7 +39,7 @@ int ConvolutionDepthwiseNPUKernel::SetConvDwParam() { conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"}); conv_dw_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0})); } else { - conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"SPECIFIC"}); + conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"}); conv_dw_->set_attr_pads( ge::AttrValue::LIST_INT({conv_param_->pad_u_, conv_param_->pad_d_, conv_param_->pad_l_, conv_param_->pad_r_})); } @@ -61,13 +61,19 @@ int ConvolutionDepthwiseNPUKernel::SetNPUInputs(const std::vectorset_input_filter(*weight_); + if (inputs.size() == 3) { + ret = InitBiasConst(inputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set bias for convolution depthwise op " << name_ << " failed when running npu"; + return RET_ERROR; + } conv_dw_->set_input_bias(*bias_); } conv_dw_->set_input_x(*npu_inputs[0]); diff --git a/mindspore/lite/src/runtime/kernel/npu/convolution_npu.cc b/mindspore/lite/src/runtime/kernel/npu/convolution_npu.cc index e36bf75d61..4187f150b5 100644 --- a/mindspore/lite/src/runtime/kernel/npu/convolution_npu.cc +++ b/mindspore/lite/src/runtime/kernel/npu/convolution_npu.cc @@ -65,13 +65,19 @@ int ConvolutionNPUKernel::SetNPUInputs(const std::vector &inputs return RET_ERROR; } - ret = InitWeightBiasConst(inputs); + ret = InitWeightConst(inputs); if (ret != RET_OK) { MS_LOG(ERROR) << "Set weight and bias for convolution op " << name_ << " failed when running npu"; return RET_ERROR; } conv_->set_input_filter(*weight_); + if (inputs.size() == 3) { + ret = InitBiasConst(inputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set bias for convolution op " << name_ << " failed when running npu"; + return RET_ERROR; + } conv_->set_input_bias(*bias_); } conv_->set_input_x(*npu_inputs[0]); diff --git a/mindspore/lite/src/runtime/kernel/npu/deconvolution_npu.cc b/mindspore/lite/src/runtime/kernel/npu/deconvolution_npu.cc index ac15301345..3a936f638e 100644 --- a/mindspore/lite/src/runtime/kernel/npu/deconvolution_npu.cc +++ b/mindspore/lite/src/runtime/kernel/npu/deconvolution_npu.cc @@ -65,13 +65,19 @@ int DeconvolutionNPUKernel::SetNPUInputs(const std::vector &inpu return RET_ERROR; } - ret = InitWeightBiasConst(inputs); + ret = InitWeightConst(inputs); if (ret != RET_OK) { MS_LOG(ERROR) << "Set weight and bias for deconvolution op " << name_ << " failed when running npu"; return RET_ERROR; } deconv_->set_input_filter(*weight_); + if (inputs.size() == 3) { + ret = InitBiasConst(inputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set bias for deconvolution op " << name_ << " failed when running npu"; + return RET_ERROR; + } deconv_->set_input_bias(*bias_); } deconv_->set_input_x(*npu_inputs[0]); diff --git a/mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.cc b/mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.cc new file mode 100644 index 0000000000..2655fcb3b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/npu/fullconnection_npu.h" +#include +#include "src/kernel_registry.h" +#include "src/runtime/agent/npu/npu_converter_utils.h" +using mindspore::kernel::KERNEL_ARCH::kNPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_FullConnection; + +namespace mindspore::kernel { +int FullconnectionNPUKernel::IsSupport(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter) { + return RET_OK; +} + +int FullconnectionNPUKernel::SetNPUInputs(const std::vector &inputs, + const std::vector &outputs, + const std::vector &npu_inputs) { + auto input_shape = inputs[0]->shape(); + reshape_ = new (std::nothrow) hiai::op::Reshape(name_ + "_reshape"); + if (reshape_ == nullptr) { + MS_LOG(ERROR) << "New reshape operator for fullconnection op " << name_ << " failed."; + return RET_ERROR; + } + reshape_->set_input_x(*npu_inputs[0]); + int col = 1; + for (int i = 1; i < input_shape.size(); i++) { + col *= input_shape[i]; + } + auto reshape_op = new (std::nothrow) hiai::op::Const(name_ + "_reshape_data"); + vector reshape_data = {input_shape[0], col}; + ge::TensorDesc reshape_tensor_desc(ge::Shape({2}), ge::FORMAT_NCHW, ge::DT_FLOAT); + ge::TensorPtr reshape_tensor = std::make_shared(reshape_tensor_desc); + reshape_tensor->SetData(reinterpret_cast(reshape_data.data()), 2 * sizeof(float)); + reshape_op->set_attr_value(reshape_tensor); + reshape_->set_input_shape(*reshape_op); + + fc_ = new (std::nothrow) hiai::op::MatMul(name_); + if (fc_ == nullptr) { + MS_LOG(ERROR) << "New matmul operator for fullconnection op " << name_ << " failed."; + return RET_ERROR; + } + fc_->set_input_x1(*reshape_); + + weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w"); + if (weight_ == nullptr) { + MS_LOG(ERROR) << "New weight const failed."; + return RET_ERROR; + } + inputs[1]->set_format(schema::Format_NCHW); + auto weight_tensor = mindspore::lite::ConverterToNPUTensor(inputs[1]); + weight_->set_attr_value(weight_tensor); + inputs[1]->set_format(schema::Format_NHWC); + fc_->set_input_x2(*weight_).set_attr_transpose_x2(true); + + if (fc_param_->has_bias_) { + biasadd_ = new (std::nothrow) hiai::op::BiasAdd(name_ + "_biasadd"); + if (biasadd_ == nullptr) { + MS_LOG(ERROR) << "New biasadd operator for fullconnection op " << name_ << " failed."; + return RET_ERROR; + } + + auto ret = InitBiasConst(inputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set bias for convolution op " << name_ << " failed when running npu"; + return RET_ERROR; + } + biasadd_->set_input_x(*fc_).set_input_bias(*bias_); + } + + if (fc_param_->act_type_ != ActType_No) { + auto ret = + biasadd_ == nullptr ? SetActivation(fc_, fc_param_->act_type_) : SetActivation(biasadd_, fc_param_->act_type_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed."; + return RET_ERROR; + } + } + return RET_OK; +} + +ge::Operator *mindspore::kernel::FullconnectionNPUKernel::GetNPUOp() { + if (fc_param_->act_type_ != ActType_No) { + return act_; + } + if (fc_param_->has_bias_) { + return biasadd_; + } + return fc_; +} + +FullconnectionNPUKernel::~FullconnectionNPUKernel() { + if (reshape_ != nullptr) { + delete reshape_; + reshape_ = nullptr; + } + if (fc_ != nullptr) { + delete fc_; + fc_ = nullptr; + } + if (biasadd_ != nullptr) { + delete biasadd_; + biasadd_ = nullptr; + } +} +REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_FullConnection, NPUKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.h b/mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.h new file mode 100644 index 0000000000..280504a4db --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/npu/fullconnection_npu.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_ +#include +#include "src/runtime/kernel/npu/convolution_base_npu.h" +#include "include/graph/op/all_ops.h" +#include "nnacl/matmul_parameter.h" +namespace mindspore::kernel { +class FullconnectionNPUKernel : public ConvolutionBaseNPUKernel { + public: + FullconnectionNPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : ConvolutionBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) { + fc_param_ = reinterpret_cast(parameter); + } + ~FullconnectionNPUKernel() override; + + int IsSupport(const std::vector &inputs, const std::vector &outputs, + OpParameter *opParameter) override; + int SetNPUInputs(const std::vector &inputs, const std::vector &outputs, + const std::vector &npu_inputs) override; + ge::Operator *GetNPUOp() override; + + private: + hiai::op::Reshape *reshape_ = nullptr; + hiai::op::MatMul *fc_ = nullptr; + hiai::op::BiasAdd *biasadd_ = nullptr; + MatMulParameter *fc_param_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_ diff --git a/mindspore/lite/src/runtime/kernel/npu/reduce_npu.cc b/mindspore/lite/src/runtime/kernel/npu/reduce_npu.cc new file mode 100644 index 0000000000..48460972da --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/npu/reduce_npu.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/npu/reduce_npu.h" +#include +#include "src/kernel_registry.h" +#include "include/graph/op/all_ops.h" +#include "src/runtime/agent/npu/npu_converter_utils.h" +using mindspore::kernel::KERNEL_ARCH::kNPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Reduce; +using mindspore::schema::ReduceMode_ReduceMean; + +namespace mindspore::kernel { +int ReduceNPUKernel::IsSupport(const std::vector &inputs, const std::vector &outputs, + OpParameter *opParameter) { + if (reduce_param_->mode_ != ReduceMode_ReduceMean) { + MS_LOG(ERROR) << "Npu does not support reduce mode " << reduce_param_->mode_ << " for op " << name_; + return RET_ERROR; + } + if (reduce_param_->reduce_to_end_) { + return RET_ERROR; + } + return RET_OK; +} + +int ReduceNPUKernel::SetNPUInputs(const std::vector &inputs, const std::vector &outputs, + const std::vector &npu_inputs) { + std::vector axes; + for (int i = 0; i < reduce_param_->num_axes_; i++) { + axes.push_back(reduce_param_->axes_[i]); + } + auto axes_op = new (std::nothrow) hiai::op::Const(name_ + "_reduce_axes"); + ge::TensorDesc axes_tensor_desc(ge::Shape({reduce_param_->num_axes_}), ge::FORMAT_NCHW, ge::DT_INT32); + ge::TensorPtr axes_tensor = std::make_shared(axes_tensor_desc); + axes_tensor->SetData(reinterpret_cast(axes.data()), reduce_param_->num_axes_ * sizeof(int32_t)); + axes_op->set_attr_value(axes_tensor); + + auto reduce_mean_ = new (std::nothrow) hiai::op::ReduceMean(name_); + if (reduce_mean_ == nullptr) { + MS_LOG(ERROR) << "New reduce operator for op " << name_ << " failed."; + return RET_ERROR; + } + reduce_mean_->set_input_x(*npu_inputs[0]).set_input_axes(*axes_op).set_attr_keep_dims(reduce_param_->keep_dims_); + reduce_ = reduce_mean_; + return RET_OK; +} + +ge::Operator *mindspore::kernel::ReduceNPUKernel::GetNPUOp() { return this->reduce_; } + +ReduceNPUKernel::~ReduceNPUKernel() { + if (reduce_ != nullptr) { + delete reduce_; + reduce_ = nullptr; + } +} + +REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Reduce, NPUKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/npu/reduce_npu.h b/mindspore/lite/src/runtime/kernel/npu/reduce_npu.h new file mode 100644 index 0000000000..81be544b29 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/npu/reduce_npu.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_ +#include +#include "nnacl/reduce_parameter.h" +#include "src/runtime/kernel/npu/npu_kernel.h" +#include "include/graph/op/all_ops.h" +namespace mindspore::kernel { +class ReduceNPUKernel : public NPUKernel { + public: + ReduceNPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : NPUKernel(parameter, inputs, outputs, ctx, primitive) { + reduce_param_ = reinterpret_cast(parameter); + } + ~ReduceNPUKernel() override; + + int IsSupport(const std::vector &inputs, const std::vector &outputs, + OpParameter *opParameter) override; + int SetNPUInputs(const std::vector &inputs, const std::vector &outputs, + const std::vector &npu_inputs) override; + ge::Operator *GetNPUOp() override; + + private: + ReduceParameter *reduce_param_; + hiai::Operator *reduce_ = nullptr; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_ diff --git a/mindspore/lite/test/models_caffe.cfg b/mindspore/lite/test/models_caffe.cfg index fc7114da1a..bc716907e6 100644 --- a/mindspore/lite/test/models_caffe.cfg +++ b/mindspore/lite/test/models_caffe.cfg @@ -68,3 +68,4 @@ ml_location_scene_division ml_tabel_recog ml_text_division 6c_seg_nomean_20200610 +ml_video_edit_person_divison diff --git a/mindspore/lite/test/models_npu.cfg b/mindspore/lite/test/models_npu.cfg index 82466639b8..640a2cbe6f 100644 --- a/mindspore/lite/test/models_npu.cfg +++ b/mindspore/lite/test/models_npu.cfg @@ -1,5 +1,28 @@ +mobilenet_v1_0.25_128.tflite 2.5 +mobilenet_v1_0.25_160.tflite 2.5 +mobilenet_v1_0.25_192.tflite 1.5 +mobilenet_v1_0.25_224.tflite 2 +mobilenet_v1_0.5_128.tflite 2 +mobilenet_v1_0.5_160.tflite 2 +mobilenet_v1_0.5_192.tflite 2.5 +mobilenet_v1_0.5_224.tflite 2 +mobilenet_v1_0.75_128.tflite 3 +mobilenet_v1_0.75_160.tflite 3 +mobilenet_v1_0.75_192.tflite 3.5 +mobilenet_v1_0.75_224.tflite 1.5 +mobilenet_v1_1.0_128.tflite 6 +mobilenet_v1_1.0_160.tflite 2 +mobilenet_v1_1.0_192.tflite 6 +mobilenet_v1_1.0_224.tflite 2.5 mobilenet_v2_1.0_224.tflite 2.5 squeezenet.tflite 2.5 inception_v3.tflite 1 +inception_v4.tflite 0.5 +efficientnet_lite0_fp32_2.tflite 1 +efficientnet_lite1_fp32_2.tflite 1 +efficientnet_lite2_fp32_2.tflite 1 +efficientnet_lite3_fp32_2.tflite 1 +efficientnet_lite4_fp32_2.tflite 1 6c_seg_nomean_20200610 1.5 +ml_video_edit_person_divison 0.5 porseg_tmp.onnx 1 2