| @@ -21,6 +21,9 @@ | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| bool CheckFusion(kernel::LiteKernel *kernel) { | bool CheckFusion(kernel::LiteKernel *kernel) { | ||||
| if (kernel->in_kernels().empty() || kernel->out_kernels().empty()) { | |||||
| return false; | |||||
| } | |||||
| auto pre_flag = | auto pre_flag = | ||||
| std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) { | std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) { | ||||
| return NPUPassUtils::IsNchw2Nhwc(in_kernel) && in_kernel->out_kernels().size() == 1; | return NPUPassUtils::IsNchw2Nhwc(in_kernel) && in_kernel->out_kernels().size() == 1; | ||||
| @@ -34,7 +34,7 @@ ConvolutionBaseNPUKernel::~ConvolutionBaseNPUKernel() { | |||||
| } | } | ||||
| } | } | ||||
| int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vector<lite::Tensor *> &inputs) { | |||||
| int ConvolutionBaseNPUKernel::InitWeightConst(const std::vector<lite::Tensor *> &inputs) { | |||||
| weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w"); | weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w"); | ||||
| if (weight_ == nullptr) { | if (weight_ == nullptr) { | ||||
| MS_LOG(ERROR) << "New weight const failed."; | MS_LOG(ERROR) << "New weight const failed."; | ||||
| @@ -61,7 +61,10 @@ int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vector<lite::Tensor | |||||
| weight_->set_attr_value(weight_tensor); | weight_->set_attr_value(weight_tensor); | ||||
| free(nchw_data); | free(nchw_data); | ||||
| return RET_OK; | |||||
| } | |||||
| int ConvolutionBaseNPUKernel::InitBiasConst(const std::vector<lite::Tensor *> &inputs) { | |||||
| if (inputs.size() >= 3) { | if (inputs.size() >= 3) { | ||||
| bias_ = new (std::nothrow) hiai::op::Const(name_ + "_b"); | bias_ = new (std::nothrow) hiai::op::Const(name_ + "_b"); | ||||
| if (bias_ == nullptr) { | if (bias_ == nullptr) { | ||||
| @@ -88,7 +91,7 @@ int ConvolutionBaseNPUKernel::SetActivation(const ge::Operator *input, ActType a | |||||
| } else if (act_type == ActType_Relu6) { | } else if (act_type == ActType_Relu6) { | ||||
| act_->set_attr_mode(14); | act_->set_attr_mode(14); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport activation for convolution."; | |||||
| MS_LOG(ERROR) << "Unsupport activation type for convolution."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -32,7 +32,8 @@ class ConvolutionBaseNPUKernel : public NPUKernel { | |||||
| ~ConvolutionBaseNPUKernel() override; | ~ConvolutionBaseNPUKernel() override; | ||||
| protected: | protected: | ||||
| int InitWeightBiasConst(const std::vector<lite::Tensor *> &inputs); | |||||
| int InitWeightConst(const std::vector<lite::Tensor *> &inputs); | |||||
| int InitBiasConst(const std::vector<lite::Tensor *> &inputs); | |||||
| int SetActivation(const ge::Operator *input, ActType act_type); | int SetActivation(const ge::Operator *input, ActType act_type); | ||||
| hiai::op::Activation *act_ = nullptr; | hiai::op::Activation *act_ = nullptr; | ||||
| hiai::op::Const *weight_ = nullptr; | hiai::op::Const *weight_ = nullptr; | ||||
| @@ -39,7 +39,7 @@ int ConvolutionDepthwiseNPUKernel::SetConvDwParam() { | |||||
| conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"}); | conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"}); | ||||
| conv_dw_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0})); | conv_dw_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0})); | ||||
| } else { | } else { | ||||
| conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"SPECIFIC"}); | |||||
| conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"}); | |||||
| conv_dw_->set_attr_pads( | conv_dw_->set_attr_pads( | ||||
| ge::AttrValue::LIST_INT({conv_param_->pad_u_, conv_param_->pad_d_, conv_param_->pad_l_, conv_param_->pad_r_})); | ge::AttrValue::LIST_INT({conv_param_->pad_u_, conv_param_->pad_d_, conv_param_->pad_l_, conv_param_->pad_r_})); | ||||
| } | } | ||||
| @@ -61,13 +61,19 @@ int ConvolutionDepthwiseNPUKernel::SetNPUInputs(const std::vector<lite::Tensor * | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = InitWeightBiasConst(inputs); | |||||
| ret = InitWeightConst(inputs); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Set weight and bias for convolution depthwise op " << name_ << " failed when running npu"; | MS_LOG(ERROR) << "Set weight and bias for convolution depthwise op " << name_ << " failed when running npu"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| conv_dw_->set_input_filter(*weight_); | conv_dw_->set_input_filter(*weight_); | ||||
| if (inputs.size() == 3) { | if (inputs.size() == 3) { | ||||
| ret = InitBiasConst(inputs); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Set bias for convolution depthwise op " << name_ << " failed when running npu"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| conv_dw_->set_input_bias(*bias_); | conv_dw_->set_input_bias(*bias_); | ||||
| } | } | ||||
| conv_dw_->set_input_x(*npu_inputs[0]); | conv_dw_->set_input_x(*npu_inputs[0]); | ||||
| @@ -65,13 +65,19 @@ int ConvolutionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = InitWeightBiasConst(inputs); | |||||
| ret = InitWeightConst(inputs); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Set weight and bias for convolution op " << name_ << " failed when running npu"; | MS_LOG(ERROR) << "Set weight and bias for convolution op " << name_ << " failed when running npu"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| conv_->set_input_filter(*weight_); | conv_->set_input_filter(*weight_); | ||||
| if (inputs.size() == 3) { | if (inputs.size() == 3) { | ||||
| ret = InitBiasConst(inputs); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Set bias for convolution op " << name_ << " failed when running npu"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| conv_->set_input_bias(*bias_); | conv_->set_input_bias(*bias_); | ||||
| } | } | ||||
| conv_->set_input_x(*npu_inputs[0]); | conv_->set_input_x(*npu_inputs[0]); | ||||
| @@ -65,13 +65,19 @@ int DeconvolutionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inpu | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = InitWeightBiasConst(inputs); | |||||
| ret = InitWeightConst(inputs); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Set weight and bias for deconvolution op " << name_ << " failed when running npu"; | MS_LOG(ERROR) << "Set weight and bias for deconvolution op " << name_ << " failed when running npu"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| deconv_->set_input_filter(*weight_); | deconv_->set_input_filter(*weight_); | ||||
| if (inputs.size() == 3) { | if (inputs.size() == 3) { | ||||
| ret = InitBiasConst(inputs); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Set bias for deconvolution op " << name_ << " failed when running npu"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| deconv_->set_input_bias(*bias_); | deconv_->set_input_bias(*bias_); | ||||
| } | } | ||||
| deconv_->set_input_x(*npu_inputs[0]); | deconv_->set_input_x(*npu_inputs[0]); | ||||
| @@ -0,0 +1,122 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/fullconnection_npu.h" | |||||
| #include <memory> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_FullConnection; | |||||
| namespace mindspore::kernel { | |||||
| int FullconnectionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | |||||
| return RET_OK; | |||||
| } | |||||
| int FullconnectionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| auto input_shape = inputs[0]->shape(); | |||||
| reshape_ = new (std::nothrow) hiai::op::Reshape(name_ + "_reshape"); | |||||
| if (reshape_ == nullptr) { | |||||
| MS_LOG(ERROR) << "New reshape operator for fullconnection op " << name_ << " failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| reshape_->set_input_x(*npu_inputs[0]); | |||||
| int col = 1; | |||||
| for (int i = 1; i < input_shape.size(); i++) { | |||||
| col *= input_shape[i]; | |||||
| } | |||||
| auto reshape_op = new (std::nothrow) hiai::op::Const(name_ + "_reshape_data"); | |||||
| vector<int> reshape_data = {input_shape[0], col}; | |||||
| ge::TensorDesc reshape_tensor_desc(ge::Shape({2}), ge::FORMAT_NCHW, ge::DT_FLOAT); | |||||
| ge::TensorPtr reshape_tensor = std::make_shared<hiai::Tensor>(reshape_tensor_desc); | |||||
| reshape_tensor->SetData(reinterpret_cast<uint8_t *>(reshape_data.data()), 2 * sizeof(float)); | |||||
| reshape_op->set_attr_value(reshape_tensor); | |||||
| reshape_->set_input_shape(*reshape_op); | |||||
| fc_ = new (std::nothrow) hiai::op::MatMul(name_); | |||||
| if (fc_ == nullptr) { | |||||
| MS_LOG(ERROR) << "New matmul operator for fullconnection op " << name_ << " failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| fc_->set_input_x1(*reshape_); | |||||
| weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w"); | |||||
| if (weight_ == nullptr) { | |||||
| MS_LOG(ERROR) << "New weight const failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| inputs[1]->set_format(schema::Format_NCHW); | |||||
| auto weight_tensor = mindspore::lite::ConverterToNPUTensor(inputs[1]); | |||||
| weight_->set_attr_value(weight_tensor); | |||||
| inputs[1]->set_format(schema::Format_NHWC); | |||||
| fc_->set_input_x2(*weight_).set_attr_transpose_x2(true); | |||||
| if (fc_param_->has_bias_) { | |||||
| biasadd_ = new (std::nothrow) hiai::op::BiasAdd(name_ + "_biasadd"); | |||||
| if (biasadd_ == nullptr) { | |||||
| MS_LOG(ERROR) << "New biasadd operator for fullconnection op " << name_ << " failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto ret = InitBiasConst(inputs); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Set bias for convolution op " << name_ << " failed when running npu"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| biasadd_->set_input_x(*fc_).set_input_bias(*bias_); | |||||
| } | |||||
| if (fc_param_->act_type_ != ActType_No) { | |||||
| auto ret = | |||||
| biasadd_ == nullptr ? SetActivation(fc_, fc_param_->act_type_) : SetActivation(biasadd_, fc_param_->act_type_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::FullconnectionNPUKernel::GetNPUOp() { | |||||
| if (fc_param_->act_type_ != ActType_No) { | |||||
| return act_; | |||||
| } | |||||
| if (fc_param_->has_bias_) { | |||||
| return biasadd_; | |||||
| } | |||||
| return fc_; | |||||
| } | |||||
| FullconnectionNPUKernel::~FullconnectionNPUKernel() { | |||||
| if (reshape_ != nullptr) { | |||||
| delete reshape_; | |||||
| reshape_ = nullptr; | |||||
| } | |||||
| if (fc_ != nullptr) { | |||||
| delete fc_; | |||||
| fc_ = nullptr; | |||||
| } | |||||
| if (biasadd_ != nullptr) { | |||||
| delete biasadd_; | |||||
| biasadd_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_FullConnection, NPUKernelCreator<FullconnectionNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/npu/convolution_base_npu.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| #include "nnacl/matmul_parameter.h" | |||||
| namespace mindspore::kernel { | |||||
| class FullconnectionNPUKernel : public ConvolutionBaseNPUKernel { | |||||
| public: | |||||
| FullconnectionNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : ConvolutionBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| fc_param_ = reinterpret_cast<MatMulParameter *>(parameter); | |||||
| } | |||||
| ~FullconnectionNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::Reshape *reshape_ = nullptr; | |||||
| hiai::op::MatMul *fc_ = nullptr; | |||||
| hiai::op::BiasAdd *biasadd_ = nullptr; | |||||
| MatMulParameter *fc_param_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_FULLCONNECTION_NPU_H_ | |||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/reduce_npu.h" | |||||
| #include <memory> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Reduce; | |||||
| using mindspore::schema::ReduceMode_ReduceMean; | |||||
| namespace mindspore::kernel { | |||||
| int ReduceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| if (reduce_param_->mode_ != ReduceMode_ReduceMean) { | |||||
| MS_LOG(ERROR) << "Npu does not support reduce mode " << reduce_param_->mode_ << " for op " << name_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (reduce_param_->reduce_to_end_) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ReduceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| std::vector<int32_t> axes; | |||||
| for (int i = 0; i < reduce_param_->num_axes_; i++) { | |||||
| axes.push_back(reduce_param_->axes_[i]); | |||||
| } | |||||
| auto axes_op = new (std::nothrow) hiai::op::Const(name_ + "_reduce_axes"); | |||||
| ge::TensorDesc axes_tensor_desc(ge::Shape({reduce_param_->num_axes_}), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| ge::TensorPtr axes_tensor = std::make_shared<hiai::Tensor>(axes_tensor_desc); | |||||
| axes_tensor->SetData(reinterpret_cast<uint8_t *>(axes.data()), reduce_param_->num_axes_ * sizeof(int32_t)); | |||||
| axes_op->set_attr_value(axes_tensor); | |||||
| auto reduce_mean_ = new (std::nothrow) hiai::op::ReduceMean(name_); | |||||
| if (reduce_mean_ == nullptr) { | |||||
| MS_LOG(ERROR) << "New reduce operator for op " << name_ << " failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| reduce_mean_->set_input_x(*npu_inputs[0]).set_input_axes(*axes_op).set_attr_keep_dims(reduce_param_->keep_dims_); | |||||
| reduce_ = reduce_mean_; | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::ReduceNPUKernel::GetNPUOp() { return this->reduce_; } | |||||
| ReduceNPUKernel::~ReduceNPUKernel() { | |||||
| if (reduce_ != nullptr) { | |||||
| delete reduce_; | |||||
| reduce_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Reduce, NPUKernelCreator<ReduceNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_ | |||||
| #include <vector> | |||||
| #include "nnacl/reduce_parameter.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class ReduceNPUKernel : public NPUKernel { | |||||
| public: | |||||
| ReduceNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| reduce_param_ = reinterpret_cast<ReduceParameter *>(parameter); | |||||
| } | |||||
| ~ReduceNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| ReduceParameter *reduce_param_; | |||||
| hiai::Operator *reduce_ = nullptr; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_REDUCE_NPU_H_ | |||||
| @@ -68,3 +68,4 @@ ml_location_scene_division | |||||
| ml_tabel_recog | ml_tabel_recog | ||||
| ml_text_division | ml_text_division | ||||
| 6c_seg_nomean_20200610 | 6c_seg_nomean_20200610 | ||||
| ml_video_edit_person_divison | |||||
| @@ -1,5 +1,28 @@ | |||||
| mobilenet_v1_0.25_128.tflite 2.5 | |||||
| mobilenet_v1_0.25_160.tflite 2.5 | |||||
| mobilenet_v1_0.25_192.tflite 1.5 | |||||
| mobilenet_v1_0.25_224.tflite 2 | |||||
| mobilenet_v1_0.5_128.tflite 2 | |||||
| mobilenet_v1_0.5_160.tflite 2 | |||||
| mobilenet_v1_0.5_192.tflite 2.5 | |||||
| mobilenet_v1_0.5_224.tflite 2 | |||||
| mobilenet_v1_0.75_128.tflite 3 | |||||
| mobilenet_v1_0.75_160.tflite 3 | |||||
| mobilenet_v1_0.75_192.tflite 3.5 | |||||
| mobilenet_v1_0.75_224.tflite 1.5 | |||||
| mobilenet_v1_1.0_128.tflite 6 | |||||
| mobilenet_v1_1.0_160.tflite 2 | |||||
| mobilenet_v1_1.0_192.tflite 6 | |||||
| mobilenet_v1_1.0_224.tflite 2.5 | |||||
| mobilenet_v2_1.0_224.tflite 2.5 | mobilenet_v2_1.0_224.tflite 2.5 | ||||
| squeezenet.tflite 2.5 | squeezenet.tflite 2.5 | ||||
| inception_v3.tflite 1 | inception_v3.tflite 1 | ||||
| inception_v4.tflite 0.5 | |||||
| efficientnet_lite0_fp32_2.tflite 1 | |||||
| efficientnet_lite1_fp32_2.tflite 1 | |||||
| efficientnet_lite2_fp32_2.tflite 1 | |||||
| efficientnet_lite3_fp32_2.tflite 1 | |||||
| efficientnet_lite4_fp32_2.tflite 1 | |||||
| 6c_seg_nomean_20200610 1.5 | 6c_seg_nomean_20200610 1.5 | ||||
| ml_video_edit_person_divison 0.5 | |||||
| porseg_tmp.onnx 1 2 | porseg_tmp.onnx 1 2 | ||||