| @@ -150,6 +150,7 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu | |||||
| for (size_t j = 0; j < size_splits_.size() - 1; ++j) { | for (size_t j = 0; j < size_splits_.size() - 1; ++j) { | ||||
| split_dim_i -= size_splits_[j]; | split_dim_i -= size_splits_[j]; | ||||
| } | } | ||||
| size_splits_[i] = split_dim_i; | |||||
| } else { | } else { | ||||
| split_dim_i = size_splits_[i]; | split_dim_i = size_splits_[i]; | ||||
| } | } | ||||
| @@ -24,6 +24,7 @@ ge::Shape ConverterToNPUShape(const std::vector<int> &src_shape) { | |||||
| } | } | ||||
| return ge::Shape({shapes}); | return ge::Shape({shapes}); | ||||
| } | } | ||||
| ge::Format ConverterToNPUFormat(schema::Format format) { | ge::Format ConverterToNPUFormat(schema::Format format) { | ||||
| ge::Format ge_format; | ge::Format ge_format; | ||||
| switch (format) { | switch (format) { | ||||
| @@ -74,13 +75,14 @@ ge::DataType ConverterToNPUDataType(TypeId type_id) { | |||||
| } | } | ||||
| return data_type; | return data_type; | ||||
| } | } | ||||
| hiai::op::Data *ConverterToNPUData(Tensor *src, const std::string &name) { | hiai::op::Data *ConverterToNPUData(Tensor *src, const std::string &name) { | ||||
| auto data = new (std::nothrow) hiai::op::Data(name); | auto data = new (std::nothrow) hiai::op::Data(name); | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| MS_LOG(ERROR) << "new data failed."; | MS_LOG(ERROR) << "new data failed."; | ||||
| return data; | return data; | ||||
| } | } | ||||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW, | |||||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()), | |||||
| ConverterToNPUDataType(src->data_type())); | ConverterToNPUDataType(src->data_type())); | ||||
| data->update_input_desc_x(tensor_desc); | data->update_input_desc_x(tensor_desc); | ||||
| return data; | return data; | ||||
| @@ -92,7 +94,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) { | |||||
| MS_LOG(ERROR) << "new ge_tensor failed."; | MS_LOG(ERROR) << "new ge_tensor failed."; | ||||
| return ge_tensor; | return ge_tensor; | ||||
| } | } | ||||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW, | |||||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()), | |||||
| ConverterToNPUDataType(src->data_type())); | ConverterToNPUDataType(src->data_type())); | ||||
| ge_tensor->SetTensorDesc(tensor_desc); | ge_tensor->SetTensorDesc(tensor_desc); | ||||
| @@ -102,62 +104,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) { | |||||
| } | } | ||||
| return ge_tensor; | return ge_tensor; | ||||
| } | } | ||||
| /* | |||||
| * mode : Activation mode, with options as follows: | |||||
| * 0 : Sigmoid | |||||
| * 1 : ReLU | |||||
| * 2 : Tanh | |||||
| * 3 : Clipped ReLU | |||||
| * 4 : ELU | |||||
| * 5 : PReLU | |||||
| * 6 : Abs | |||||
| * 7 : Relu1 | |||||
| * 8 : Softsign | |||||
| * 9 : Softplus | |||||
| * 10 : Hardsigmoid | |||||
| * 11 : Threshold ReLU | |||||
| * 12 : Selu | |||||
| * 13 : Linear | |||||
| * 14 : Relu6 | |||||
| * 15 : GeLU. | |||||
| */ | |||||
| int ConverterToNPUActMode(schema::ActivationType type) { | |||||
| switch (type) { | |||||
| case schema::ActivationType_NO_ACTIVATION: | |||||
| return -1; | |||||
| case schema::ActivationType_SIGMOID: | |||||
| return 0; | |||||
| case schema::ActivationType_RELU: | |||||
| return 1; | |||||
| case schema::ActivationType_TANH: | |||||
| return 2; | |||||
| case schema::ActivationType_ELU: | |||||
| return 4; | |||||
| case schema::ActivationType_LEAKY_RELU: | |||||
| return 5; | |||||
| case schema::ActivationType_ABS: | |||||
| return 6; | |||||
| case schema::ActivationType_RELU1: | |||||
| return 7; | |||||
| case schema::ActivationType_SOFTSIGN: | |||||
| return 8; | |||||
| case schema::ActivationType_SOFTPLUS: | |||||
| return 9; | |||||
| case schema::ActivationType_HSIGMOID: | |||||
| return 10; | |||||
| case schema::ActivationType_THRESHOLDRELU: | |||||
| return 11; | |||||
| case schema::ActivationType_SELU: | |||||
| return 12; | |||||
| case schema::ActivationType_LINEAR: | |||||
| return 13; | |||||
| case schema::ActivationType_RELU6: | |||||
| return 14; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupport activation type to NPU." << type; | |||||
| return -1; | |||||
| } | |||||
| } | |||||
| // mode : Either 0 (product), 1 (sum), 2 (max), 3 (mean). Defaults to 1 (sum). | // mode : Either 0 (product), 1 (sum), 2 (max), 3 (mean). Defaults to 1 (sum). | ||||
| int ConverterToNPUEltwiseMode(schema::EltwiseMode mode) { | int ConverterToNPUEltwiseMode(schema::EltwiseMode mode) { | ||||
| int mode_num = 1; | int mode_num = 1; | ||||
| @@ -53,6 +53,7 @@ int NPUExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &o | |||||
| for (int i = 0; i < npu_output_tensors_.size(); ++i) { | for (int i = 0; i < npu_output_tensors_.size(); ++i) { | ||||
| memcpy(out_tensors[i]->MutableData(), npu_output_tensors_[i]->GetBuffer(), npu_output_tensors_[i]->GetSize()); | memcpy(out_tensors[i]->MutableData(), npu_output_tensors_[i]->GetBuffer(), npu_output_tensors_[i]->GetSize()); | ||||
| out_tensors[i]->ResetRefCount(); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -83,16 +83,22 @@ int SubGraphNpuKernel::BuildNPUInputOp() { | |||||
| for (auto in_tensor : node->in_tensors()) { | for (auto in_tensor : node->in_tensors()) { | ||||
| if (IsSubGraphInputTensor(in_tensor)) { | if (IsSubGraphInputTensor(in_tensor)) { | ||||
| auto tensor_name = node->name() + "_" + std::to_string(count++); | auto tensor_name = node->name() + "_" + std::to_string(count++); | ||||
| auto shape = in_tensor->shape(); | |||||
| hiai::op::Data *data; | |||||
| if (trans_nodes.find(node->Type()) != trans_nodes.end()) { | if (trans_nodes.find(node->Type()) != trans_nodes.end()) { | ||||
| in_tensor->set_shape({shape[0], shape[3], shape[1], shape[2]}); | |||||
| auto shape = in_tensor->shape(); | |||||
| data = new (std::nothrow) hiai::op::Data(tensor_name); | |||||
| if (data == nullptr) { | |||||
| MS_LOG(ERROR) << "New data failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| ge::TensorDesc tensor_desc(lite::ConverterToNPUShape({shape[0], shape[3], shape[1], shape[2]}), | |||||
| ge::FORMAT_NCHW, lite::ConverterToNPUDataType(in_tensor->data_type())); | |||||
| data->update_input_desc_x(tensor_desc); | |||||
| } else { | |||||
| data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name); | |||||
| } | } | ||||
| auto data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name); | |||||
| subgraph_input_op_.push_back(*data); | subgraph_input_op_.push_back(*data); | ||||
| node_input_op.push_back(data); | node_input_op.push_back(data); | ||||
| if (trans_nodes.find(node->Type()) != trans_nodes.end()) { | |||||
| in_tensor->set_shape(shape); | |||||
| } | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -120,13 +126,11 @@ int SubGraphNpuKernel::BuildNPUInputOp() { | |||||
| // weight tensor | // weight tensor | ||||
| if (is_weight_tensor) { | if (is_weight_tensor) { | ||||
| if (!(node->Type() == schema::PrimitiveType_Conv2D || node->Type() == schema::PrimitiveType_DeConv2D || | |||||
| node->Type() == schema::PrimitiveType_DepthwiseConv2D || | |||||
| node->Type() == schema::PrimitiveType_DeDepthwiseConv2D)) { | |||||
| if (trans_nodes.find(node->Type()) == trans_nodes.end()) { | |||||
| auto name = node->name() + "_" + std::to_string(count++); | auto name = node->name() + "_" + std::to_string(count++); | ||||
| auto weight_const = new (std::nothrow) hiai::op::Const(node->name() + "_" + std::to_string(count++)); | auto weight_const = new (std::nothrow) hiai::op::Const(node->name() + "_" + std::to_string(count++)); | ||||
| if (weight_const == nullptr) { | if (weight_const == nullptr) { | ||||
| MS_LOG(ERROR) << "new weight const failed."; | |||||
| MS_LOG(ERROR) << "New weight const failed."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto weight_tensor = mindspore::lite::ConverterToNPUTensor(in_tensor); | auto weight_tensor = mindspore::lite::ConverterToNPUTensor(in_tensor); | ||||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Concat; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConcatNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | int ConcatNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | ||||
| OpParameter *opParameter) { | OpParameter *opParameter) { | ||||
| return RET_OK; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| int ConcatNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | int ConcatNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | ||||
| @@ -25,7 +25,7 @@ using mindspore::schema::PrimitiveType_DepthwiseConv2D; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionDepthwiseNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | int ConvolutionDepthwiseNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | ||||
| return RET_OK; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| int ConvolutionDepthwiseNPUKernel::SetConvDwParam() { | int ConvolutionDepthwiseNPUKernel::SetConvDwParam() { | ||||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Conv2D; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | ||||
| return RET_OK; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| int ConvolutionNPUKernel::SetConvParam() { | int ConvolutionNPUKernel::SetConvParam() { | ||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/matmul_npu.h" | |||||
| #include "src/kernel_registry.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_MatMul; | |||||
| namespace mindspore::kernel { | |||||
| int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| return RET_OK; | |||||
| } | |||||
| int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| op_ = new (std::nothrow) hiai::op::MatMul(name_); | |||||
| op_->set_input_x1(*npu_inputs[0]); | |||||
| op_->set_input_x2(*npu_inputs[1]); | |||||
| op_->set_attr_transpose_x1(a_transpose_); | |||||
| op_->set_attr_transpose_x2(b_transpose_); | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; } | |||||
| MatMulNPUKernel::~MatMulNPUKernel() { | |||||
| if (op_ != nullptr) { | |||||
| delete op_; | |||||
| op_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_ | |||||
| #include <vector> | |||||
| #include "nnacl/matmul_parameter.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "nnacl/softmax_parameter.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class MatMulNPUKernel : public NPUKernel { | |||||
| public: | |||||
| MatMulNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| auto matmul_parameter = reinterpret_cast<MatMulParameter *>(parameter); | |||||
| a_transpose_ = matmul_parameter->a_transpose_; | |||||
| b_transpose_ = matmul_parameter->b_transpose_; | |||||
| } | |||||
| ~MatMulNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::MatMul *op_ = nullptr; | |||||
| bool a_transpose_ = false; | |||||
| bool b_transpose_ = false; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_ | |||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/pad_npu.h" | |||||
| #include <memory> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Pad; | |||||
| namespace mindspore::kernel { | |||||
| int PadNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| if (padding_mode_ != schema::PaddingMode_CONSTANT) { | |||||
| MS_LOG(WARNING) << "NPU only support CONSTANT padding mode"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int PadNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| op_ = new (std::nothrow) hiai::op::PadV2(name_); | |||||
| if (op_ == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int size = static_cast<int>(paddings_.size() / 2); | |||||
| ge::TensorDesc padding_tensor_desc(ge::Shape({size, 2}), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| ge::TensorPtr padding_tensor = std::make_shared<hiai::Tensor>(padding_tensor_desc); | |||||
| padding_tensor->SetData(reinterpret_cast<uint8_t *>(paddings_.data()), size * sizeof(int)); | |||||
| auto paddings = new hiai::op::Const(name_ + "paddings"); | |||||
| paddings->set_attr_value(padding_tensor); | |||||
| ge::TensorDesc constant_values_tensor_desc(ge::Shape({1}), ge::FORMAT_NCHW, ge::DT_FLOAT); | |||||
| ge::TensorPtr constant_values_tensor = std::make_shared<hiai::Tensor>(constant_values_tensor_desc); | |||||
| vector<float> constant_values_data_value = {constant_value_}; | |||||
| constant_values_tensor->SetData(reinterpret_cast<uint8_t *>(constant_values_data_value.data()), 1 * sizeof(float)); | |||||
| auto constant = new hiai::op::Const(name_ + "constant"); | |||||
| constant->set_attr_value(constant_values_tensor); | |||||
| op_->set_input_x(*npu_inputs[0]); | |||||
| op_->set_input_constant_values(*constant); | |||||
| op_->set_input_paddings(*paddings); | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::PadNPUKernel::GetNPUOp() { return this->op_; } | |||||
| PadNPUKernel::~PadNPUKernel() { | |||||
| if (op_ != nullptr) { | |||||
| delete op_; | |||||
| op_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Pad, NPUKernelCreator<PadNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_ | |||||
| #include <vector> | |||||
| #include "nnacl/pad_parameter.h" | |||||
| #include "src/ops/pad.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class PadNPUKernel : public NPUKernel { | |||||
| public: | |||||
| PadNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| auto pad = reinterpret_cast<const mindspore::lite::Pad *>(primitive); | |||||
| constant_value_ = pad->GetConstantValue(); | |||||
| paddings_ = pad->GetPaddings(); | |||||
| padding_mode_ = pad->GetPaddingMode(); | |||||
| } | |||||
| ~PadNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::PadV2 *op_ = nullptr; | |||||
| std::vector<int> paddings_; | |||||
| int padding_mode_; | |||||
| float constant_value_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_ | |||||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Pooling; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | ||||
| OpParameter *opParameter) { | OpParameter *opParameter) { | ||||
| return RET_OK; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| int PoolingNPUKernel::SetPoolingParam() { | int PoolingNPUKernel::SetPoolingParam() { | ||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/slice_npu.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Slice; | |||||
| namespace mindspore::kernel { | |||||
| int SliceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| return RET_OK; | |||||
| } | |||||
| int SliceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| op_ = new (std::nothrow) hiai::op::Slice(name_); | |||||
| if (op_ == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| op_->set_input_x(*npu_inputs[0]); | |||||
| op_->set_input_offsets(*npu_inputs[1]); | |||||
| op_->set_input_size(*npu_inputs[2]); | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::SliceNPUKernel::GetNPUOp() { return this->op_; } | |||||
| SliceNPUKernel::~SliceNPUKernel() { | |||||
| if (op_ != nullptr) { | |||||
| delete op_; | |||||
| op_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Slice, NPUKernelCreator<SliceNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_ | |||||
| #include <vector> | |||||
| #include "src/ops/slice.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class SliceNPUKernel : public NPUKernel { | |||||
| public: | |||||
| SliceNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~SliceNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::Slice *op_ = nullptr; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_ | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/split_npu.h" | |||||
| #include <memory> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Split; | |||||
| namespace mindspore::kernel { | |||||
| int SplitNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| return RET_OK; | |||||
| } | |||||
| int SplitNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| op_ = new (std::nothrow) hiai::op::SplitV(name_); | |||||
| if (op_ == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int size = size_splits_.size(); | |||||
| ge::TensorDesc size_splits_tensor_desc(ge::Shape({size}), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| ge::TensorPtr size_splits_tensor = std::make_shared<hiai::Tensor>(size_splits_tensor_desc); | |||||
| size_splits_tensor->SetData(reinterpret_cast<uint8_t *>(size_splits_.data()), size * sizeof(int)); | |||||
| auto size_splits = new hiai::op::Const(name_ + "_size"); | |||||
| size_splits->set_attr_value(size_splits_tensor); | |||||
| ge::TensorDesc split_dim_tensor_desc(ge::Shape({1}), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| ge::TensorPtr split_dim_tensor = std::make_shared<hiai::Tensor>(split_dim_tensor_desc); | |||||
| vector<int32_t> split_dim_data_value = {split_dim_}; | |||||
| split_dim_tensor->SetData(reinterpret_cast<uint8_t *>(split_dim_data_value.data()), 1 * sizeof(int)); | |||||
| auto split_dim = new hiai::op::Const(name_ + "_dim"); | |||||
| split_dim->set_attr_value(split_dim_tensor); | |||||
| op_->set_input_x(*npu_inputs[0]); | |||||
| op_->set_attr_num_split(num_split_); | |||||
| op_->set_input_split_dim(*split_dim); | |||||
| op_->set_input_size_splits(*size_splits); | |||||
| op_->create_dynamic_output_y(num_split_); | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::SplitNPUKernel::GetNPUOp() { return this->op_; } | |||||
| SplitNPUKernel::~SplitNPUKernel() { | |||||
| if (op_ != nullptr) { | |||||
| delete op_; | |||||
| op_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Split, NPUKernelCreator<SplitNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_ | |||||
| #include <vector> | |||||
| #include "src/ops/split.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class SplitNPUKernel : public NPUKernel { | |||||
| public: | |||||
| SplitNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| auto split = reinterpret_cast<const mindspore::lite::Split *>(primitive); | |||||
| num_split_ = split->GetNumberSplit(); | |||||
| size_splits_ = split->GetSizeSplit(); | |||||
| split_dim_ = split->GetSplitDim(); | |||||
| } | |||||
| ~SplitNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::SplitV *op_ = nullptr; | |||||
| int num_split_; | |||||
| std::vector<int> size_splits_; | |||||
| int split_dim_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_ | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/transpose_npu.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Nchw2Nhwc; | |||||
| using mindspore::schema::PrimitiveType_Nhwc2Nchw; | |||||
| using mindspore::schema::PrimitiveType_Transpose; | |||||
| namespace mindspore::kernel { | |||||
| int TransposeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| if (conjugate_) { | |||||
| MS_LOG(ERROR) << "Unsupported conjugate transpose."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int TransposeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| op_ = new (std::nothrow) hiai::op::Permute(name_); | |||||
| if (op_ == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| op_->set_input_x(*npu_inputs[0]); | |||||
| op_->set_attr_order(perm_); | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::TransposeNPUKernel::GetNPUOp() { return this->op_; } | |||||
| TransposeNPUKernel::~TransposeNPUKernel() { | |||||
| if (op_ != nullptr) { | |||||
| delete op_; | |||||
| op_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Transpose, NPUKernelCreator<TransposeNPUKernel>) | |||||
| // REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, NPUKernelCreator<TransposeNPUKernel>) | |||||
| // REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, NPUKernelCreator<TransposeNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_ | |||||
| #include <vector> | |||||
| #include "nnacl/transpose.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class TransposeNPUKernel : public NPUKernel { | |||||
| public: | |||||
| TransposeNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| if (primitive->Type() == schema::PrimitiveType_Transpose) { | |||||
| auto transpose_parameter = reinterpret_cast<TransposeParameter *>(parameter); | |||||
| conjugate_ = transpose_parameter->conjugate_; | |||||
| for (int i = 0; i < transpose_parameter->num_axes_; i++) { | |||||
| perm_.push_back(transpose_parameter->perm_[i]); | |||||
| } | |||||
| } else if (primitive->Type() == schema::PrimitiveType_Nchw2Nhwc) { | |||||
| perm_ = {0, 2, 3, 1}; | |||||
| } else if (primitive->Type() == schema::PrimitiveType_Nhwc2Nchw) { | |||||
| perm_ = {0, 3, 1, 2}; | |||||
| } | |||||
| } | |||||
| ~TransposeNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::Permute *op_ = nullptr; | |||||
| std::vector<int64_t> perm_; | |||||
| bool conjugate_ = false; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_ | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/npu/unsqueeze_npu.h" | |||||
| #include <memory> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Unsqueeze; | |||||
| namespace mindspore::kernel { | |||||
| int UnsqueezeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) { | |||||
| if (inputs[0]->shape().size() > 3) { | |||||
| MS_LOG(WARNING) << "The dimension of output not support bigger than 4."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int UnsqueezeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) { | |||||
| op_ = new (std::nothrow) hiai::op::ExpandDims(name_); | |||||
| if (op_ == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int size = axis_.size(); | |||||
| ge::TensorDesc desc(ge::Shape({size}), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| ge::TensorPtr tensor = std::make_shared<hiai::Tensor>(desc); | |||||
| tensor->SetData(reinterpret_cast<uint8_t *>(axis_.data()), size * sizeof(int)); | |||||
| auto axis = new hiai::op::Const(name_ + "_axis"); | |||||
| axis->set_attr_value(tensor); | |||||
| op_->set_input_x(*npu_inputs[0]); | |||||
| op_->set_input_axis(*axis); | |||||
| return RET_OK; | |||||
| } | |||||
| ge::Operator *mindspore::kernel::UnsqueezeNPUKernel::GetNPUOp() { return this->op_; } | |||||
| UnsqueezeNPUKernel::~UnsqueezeNPUKernel() { | |||||
| if (op_ != nullptr) { | |||||
| delete op_; | |||||
| op_ = nullptr; | |||||
| } | |||||
| } | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, NPUKernelCreator<UnsqueezeNPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_ | |||||
| #include <vector> | |||||
| #include "src/ops/unsqueeze.h" | |||||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||||
| #include "include/graph/op/all_ops.h" | |||||
| namespace mindspore::kernel { | |||||
| class UnsqueezeNPUKernel : public NPUKernel { | |||||
| public: | |||||
| UnsqueezeNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| auto unsqueeze = reinterpret_cast<const mindspore::lite::Unsqueeze *>(primitive); | |||||
| axis_ = unsqueeze->GetAxis(); | |||||
| } | |||||
| ~UnsqueezeNPUKernel() override; | |||||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter) override; | |||||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||||
| ge::Operator *GetNPUOp() override; | |||||
| private: | |||||
| hiai::op::ExpandDims *op_ = nullptr; | |||||
| vector<int> axis_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_ | |||||