| @@ -150,6 +150,7 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu | |||
| for (size_t j = 0; j < size_splits_.size() - 1; ++j) { | |||
| split_dim_i -= size_splits_[j]; | |||
| } | |||
| size_splits_[i] = split_dim_i; | |||
| } else { | |||
| split_dim_i = size_splits_[i]; | |||
| } | |||
| @@ -24,6 +24,7 @@ ge::Shape ConverterToNPUShape(const std::vector<int> &src_shape) { | |||
| } | |||
| return ge::Shape({shapes}); | |||
| } | |||
| ge::Format ConverterToNPUFormat(schema::Format format) { | |||
| ge::Format ge_format; | |||
| switch (format) { | |||
| @@ -74,13 +75,14 @@ ge::DataType ConverterToNPUDataType(TypeId type_id) { | |||
| } | |||
| return data_type; | |||
| } | |||
| hiai::op::Data *ConverterToNPUData(Tensor *src, const std::string &name) { | |||
| auto data = new (std::nothrow) hiai::op::Data(name); | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "new data failed."; | |||
| return data; | |||
| } | |||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW, | |||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()), | |||
| ConverterToNPUDataType(src->data_type())); | |||
| data->update_input_desc_x(tensor_desc); | |||
| return data; | |||
| @@ -92,7 +94,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) { | |||
| MS_LOG(ERROR) << "new ge_tensor failed."; | |||
| return ge_tensor; | |||
| } | |||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW, | |||
| ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()), | |||
| ConverterToNPUDataType(src->data_type())); | |||
| ge_tensor->SetTensorDesc(tensor_desc); | |||
| @@ -102,62 +104,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) { | |||
| } | |||
| return ge_tensor; | |||
| } | |||
| /* | |||
| * mode : Activation mode, with options as follows: | |||
| * 0 : Sigmoid | |||
| * 1 : ReLU | |||
| * 2 : Tanh | |||
| * 3 : Clipped ReLU | |||
| * 4 : ELU | |||
| * 5 : PReLU | |||
| * 6 : Abs | |||
| * 7 : Relu1 | |||
| * 8 : Softsign | |||
| * 9 : Softplus | |||
| * 10 : Hardsigmoid | |||
| * 11 : Threshold ReLU | |||
| * 12 : Selu | |||
| * 13 : Linear | |||
| * 14 : Relu6 | |||
| * 15 : GeLU. | |||
| */ | |||
| int ConverterToNPUActMode(schema::ActivationType type) { | |||
| switch (type) { | |||
| case schema::ActivationType_NO_ACTIVATION: | |||
| return -1; | |||
| case schema::ActivationType_SIGMOID: | |||
| return 0; | |||
| case schema::ActivationType_RELU: | |||
| return 1; | |||
| case schema::ActivationType_TANH: | |||
| return 2; | |||
| case schema::ActivationType_ELU: | |||
| return 4; | |||
| case schema::ActivationType_LEAKY_RELU: | |||
| return 5; | |||
| case schema::ActivationType_ABS: | |||
| return 6; | |||
| case schema::ActivationType_RELU1: | |||
| return 7; | |||
| case schema::ActivationType_SOFTSIGN: | |||
| return 8; | |||
| case schema::ActivationType_SOFTPLUS: | |||
| return 9; | |||
| case schema::ActivationType_HSIGMOID: | |||
| return 10; | |||
| case schema::ActivationType_THRESHOLDRELU: | |||
| return 11; | |||
| case schema::ActivationType_SELU: | |||
| return 12; | |||
| case schema::ActivationType_LINEAR: | |||
| return 13; | |||
| case schema::ActivationType_RELU6: | |||
| return 14; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport activation type to NPU." << type; | |||
| return -1; | |||
| } | |||
| } | |||
| // mode : Either 0 (product), 1 (sum), 2 (max), 3 (mean). Defaults to 1 (sum). | |||
| int ConverterToNPUEltwiseMode(schema::EltwiseMode mode) { | |||
| int mode_num = 1; | |||
| @@ -53,6 +53,7 @@ int NPUExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &o | |||
| for (int i = 0; i < npu_output_tensors_.size(); ++i) { | |||
| memcpy(out_tensors[i]->MutableData(), npu_output_tensors_[i]->GetBuffer(), npu_output_tensors_[i]->GetSize()); | |||
| out_tensors[i]->ResetRefCount(); | |||
| } | |||
| return RET_OK; | |||
| @@ -83,16 +83,22 @@ int SubGraphNpuKernel::BuildNPUInputOp() { | |||
| for (auto in_tensor : node->in_tensors()) { | |||
| if (IsSubGraphInputTensor(in_tensor)) { | |||
| auto tensor_name = node->name() + "_" + std::to_string(count++); | |||
| auto shape = in_tensor->shape(); | |||
| hiai::op::Data *data; | |||
| if (trans_nodes.find(node->Type()) != trans_nodes.end()) { | |||
| in_tensor->set_shape({shape[0], shape[3], shape[1], shape[2]}); | |||
| auto shape = in_tensor->shape(); | |||
| data = new (std::nothrow) hiai::op::Data(tensor_name); | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "New data failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ge::TensorDesc tensor_desc(lite::ConverterToNPUShape({shape[0], shape[3], shape[1], shape[2]}), | |||
| ge::FORMAT_NCHW, lite::ConverterToNPUDataType(in_tensor->data_type())); | |||
| data->update_input_desc_x(tensor_desc); | |||
| } else { | |||
| data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name); | |||
| } | |||
| auto data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name); | |||
| subgraph_input_op_.push_back(*data); | |||
| node_input_op.push_back(data); | |||
| if (trans_nodes.find(node->Type()) != trans_nodes.end()) { | |||
| in_tensor->set_shape(shape); | |||
| } | |||
| continue; | |||
| } | |||
| @@ -120,13 +126,11 @@ int SubGraphNpuKernel::BuildNPUInputOp() { | |||
| // weight tensor | |||
| if (is_weight_tensor) { | |||
| if (!(node->Type() == schema::PrimitiveType_Conv2D || node->Type() == schema::PrimitiveType_DeConv2D || | |||
| node->Type() == schema::PrimitiveType_DepthwiseConv2D || | |||
| node->Type() == schema::PrimitiveType_DeDepthwiseConv2D)) { | |||
| if (trans_nodes.find(node->Type()) == trans_nodes.end()) { | |||
| auto name = node->name() + "_" + std::to_string(count++); | |||
| auto weight_const = new (std::nothrow) hiai::op::Const(node->name() + "_" + std::to_string(count++)); | |||
| if (weight_const == nullptr) { | |||
| MS_LOG(ERROR) << "new weight const failed."; | |||
| MS_LOG(ERROR) << "New weight const failed."; | |||
| return RET_ERROR; | |||
| } | |||
| auto weight_tensor = mindspore::lite::ConverterToNPUTensor(in_tensor); | |||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Concat; | |||
| namespace mindspore::kernel { | |||
| int ConcatNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| return RET_OK; | |||
| return RET_ERROR; | |||
| } | |||
| int ConcatNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| @@ -25,7 +25,7 @@ using mindspore::schema::PrimitiveType_DepthwiseConv2D; | |||
| namespace mindspore::kernel { | |||
| int ConvolutionDepthwiseNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | |||
| return RET_OK; | |||
| return RET_ERROR; | |||
| } | |||
| int ConvolutionDepthwiseNPUKernel::SetConvDwParam() { | |||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Conv2D; | |||
| namespace mindspore::kernel { | |||
| int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) { | |||
| return RET_OK; | |||
| return RET_ERROR; | |||
| } | |||
| int ConvolutionNPUKernel::SetConvParam() { | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/npu/matmul_npu.h" | |||
| #include "src/kernel_registry.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_MatMul; | |||
| namespace mindspore::kernel { | |||
| int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| return RET_OK; | |||
| } | |||
| int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) { | |||
| op_ = new (std::nothrow) hiai::op::MatMul(name_); | |||
| op_->set_input_x1(*npu_inputs[0]); | |||
| op_->set_input_x2(*npu_inputs[1]); | |||
| op_->set_attr_transpose_x1(a_transpose_); | |||
| op_->set_attr_transpose_x2(b_transpose_); | |||
| return RET_OK; | |||
| } | |||
| ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; } | |||
| MatMulNPUKernel::~MatMulNPUKernel() { | |||
| if (op_ != nullptr) { | |||
| delete op_; | |||
| op_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_ | |||
| #include <vector> | |||
| #include "nnacl/matmul_parameter.h" | |||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||
| #include "nnacl/softmax_parameter.h" | |||
| #include "include/graph/op/all_ops.h" | |||
| namespace mindspore::kernel { | |||
| class MatMulNPUKernel : public NPUKernel { | |||
| public: | |||
| MatMulNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| auto matmul_parameter = reinterpret_cast<MatMulParameter *>(parameter); | |||
| a_transpose_ = matmul_parameter->a_transpose_; | |||
| b_transpose_ = matmul_parameter->b_transpose_; | |||
| } | |||
| ~MatMulNPUKernel() override; | |||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) override; | |||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||
| ge::Operator *GetNPUOp() override; | |||
| private: | |||
| hiai::op::MatMul *op_ = nullptr; | |||
| bool a_transpose_ = false; | |||
| bool b_transpose_ = false; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_ | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/npu/pad_npu.h" | |||
| #include <memory> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_Pad; | |||
| namespace mindspore::kernel { | |||
| int PadNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| if (padding_mode_ != schema::PaddingMode_CONSTANT) { | |||
| MS_LOG(WARNING) << "NPU only support CONSTANT padding mode"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PadNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) { | |||
| op_ = new (std::nothrow) hiai::op::PadV2(name_); | |||
| if (op_ == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| int size = static_cast<int>(paddings_.size() / 2); | |||
| ge::TensorDesc padding_tensor_desc(ge::Shape({size, 2}), ge::FORMAT_NCHW, ge::DT_INT32); | |||
| ge::TensorPtr padding_tensor = std::make_shared<hiai::Tensor>(padding_tensor_desc); | |||
| padding_tensor->SetData(reinterpret_cast<uint8_t *>(paddings_.data()), size * sizeof(int)); | |||
| auto paddings = new hiai::op::Const(name_ + "paddings"); | |||
| paddings->set_attr_value(padding_tensor); | |||
| ge::TensorDesc constant_values_tensor_desc(ge::Shape({1}), ge::FORMAT_NCHW, ge::DT_FLOAT); | |||
| ge::TensorPtr constant_values_tensor = std::make_shared<hiai::Tensor>(constant_values_tensor_desc); | |||
| vector<float> constant_values_data_value = {constant_value_}; | |||
| constant_values_tensor->SetData(reinterpret_cast<uint8_t *>(constant_values_data_value.data()), 1 * sizeof(float)); | |||
| auto constant = new hiai::op::Const(name_ + "constant"); | |||
| constant->set_attr_value(constant_values_tensor); | |||
| op_->set_input_x(*npu_inputs[0]); | |||
| op_->set_input_constant_values(*constant); | |||
| op_->set_input_paddings(*paddings); | |||
| return RET_OK; | |||
| } | |||
| ge::Operator *mindspore::kernel::PadNPUKernel::GetNPUOp() { return this->op_; } | |||
| PadNPUKernel::~PadNPUKernel() { | |||
| if (op_ != nullptr) { | |||
| delete op_; | |||
| op_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Pad, NPUKernelCreator<PadNPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_ | |||
| #include <vector> | |||
| #include "nnacl/pad_parameter.h" | |||
| #include "src/ops/pad.h" | |||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||
| #include "include/graph/op/all_ops.h" | |||
| namespace mindspore::kernel { | |||
| class PadNPUKernel : public NPUKernel { | |||
| public: | |||
| PadNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| auto pad = reinterpret_cast<const mindspore::lite::Pad *>(primitive); | |||
| constant_value_ = pad->GetConstantValue(); | |||
| paddings_ = pad->GetPaddings(); | |||
| padding_mode_ = pad->GetPaddingMode(); | |||
| } | |||
| ~PadNPUKernel() override; | |||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) override; | |||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||
| ge::Operator *GetNPUOp() override; | |||
| private: | |||
| hiai::op::PadV2 *op_ = nullptr; | |||
| std::vector<int> paddings_; | |||
| int padding_mode_; | |||
| float constant_value_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_ | |||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Pooling; | |||
| namespace mindspore::kernel { | |||
| int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| return RET_OK; | |||
| return RET_ERROR; | |||
| } | |||
| int PoolingNPUKernel::SetPoolingParam() { | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/npu/slice_npu.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_Slice; | |||
| namespace mindspore::kernel { | |||
| int SliceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| return RET_OK; | |||
| } | |||
| int SliceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) { | |||
| op_ = new (std::nothrow) hiai::op::Slice(name_); | |||
| if (op_ == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| op_->set_input_x(*npu_inputs[0]); | |||
| op_->set_input_offsets(*npu_inputs[1]); | |||
| op_->set_input_size(*npu_inputs[2]); | |||
| return RET_OK; | |||
| } | |||
| ge::Operator *mindspore::kernel::SliceNPUKernel::GetNPUOp() { return this->op_; } | |||
| SliceNPUKernel::~SliceNPUKernel() { | |||
| if (op_ != nullptr) { | |||
| delete op_; | |||
| op_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Slice, NPUKernelCreator<SliceNPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_ | |||
| #include <vector> | |||
| #include "src/ops/slice.h" | |||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||
| #include "include/graph/op/all_ops.h" | |||
| namespace mindspore::kernel { | |||
| class SliceNPUKernel : public NPUKernel { | |||
| public: | |||
| SliceNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~SliceNPUKernel() override; | |||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) override; | |||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||
| ge::Operator *GetNPUOp() override; | |||
| private: | |||
| hiai::op::Slice *op_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/npu/split_npu.h" | |||
| #include <memory> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_Split; | |||
| namespace mindspore::kernel { | |||
| int SplitNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| return RET_OK; | |||
| } | |||
| int SplitNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) { | |||
| op_ = new (std::nothrow) hiai::op::SplitV(name_); | |||
| if (op_ == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| int size = size_splits_.size(); | |||
| ge::TensorDesc size_splits_tensor_desc(ge::Shape({size}), ge::FORMAT_NCHW, ge::DT_INT32); | |||
| ge::TensorPtr size_splits_tensor = std::make_shared<hiai::Tensor>(size_splits_tensor_desc); | |||
| size_splits_tensor->SetData(reinterpret_cast<uint8_t *>(size_splits_.data()), size * sizeof(int)); | |||
| auto size_splits = new hiai::op::Const(name_ + "_size"); | |||
| size_splits->set_attr_value(size_splits_tensor); | |||
| ge::TensorDesc split_dim_tensor_desc(ge::Shape({1}), ge::FORMAT_NCHW, ge::DT_INT32); | |||
| ge::TensorPtr split_dim_tensor = std::make_shared<hiai::Tensor>(split_dim_tensor_desc); | |||
| vector<int32_t> split_dim_data_value = {split_dim_}; | |||
| split_dim_tensor->SetData(reinterpret_cast<uint8_t *>(split_dim_data_value.data()), 1 * sizeof(int)); | |||
| auto split_dim = new hiai::op::Const(name_ + "_dim"); | |||
| split_dim->set_attr_value(split_dim_tensor); | |||
| op_->set_input_x(*npu_inputs[0]); | |||
| op_->set_attr_num_split(num_split_); | |||
| op_->set_input_split_dim(*split_dim); | |||
| op_->set_input_size_splits(*size_splits); | |||
| op_->create_dynamic_output_y(num_split_); | |||
| return RET_OK; | |||
| } | |||
| ge::Operator *mindspore::kernel::SplitNPUKernel::GetNPUOp() { return this->op_; } | |||
| SplitNPUKernel::~SplitNPUKernel() { | |||
| if (op_ != nullptr) { | |||
| delete op_; | |||
| op_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Split, NPUKernelCreator<SplitNPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_ | |||
| #include <vector> | |||
| #include "src/ops/split.h" | |||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||
| #include "include/graph/op/all_ops.h" | |||
| namespace mindspore::kernel { | |||
| class SplitNPUKernel : public NPUKernel { | |||
| public: | |||
| SplitNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| auto split = reinterpret_cast<const mindspore::lite::Split *>(primitive); | |||
| num_split_ = split->GetNumberSplit(); | |||
| size_splits_ = split->GetSizeSplit(); | |||
| split_dim_ = split->GetSplitDim(); | |||
| } | |||
| ~SplitNPUKernel() override; | |||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) override; | |||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||
| ge::Operator *GetNPUOp() override; | |||
| private: | |||
| hiai::op::SplitV *op_ = nullptr; | |||
| int num_split_; | |||
| std::vector<int> size_splits_; | |||
| int split_dim_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_ | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/npu/transpose_npu.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_Nchw2Nhwc; | |||
| using mindspore::schema::PrimitiveType_Nhwc2Nchw; | |||
| using mindspore::schema::PrimitiveType_Transpose; | |||
| namespace mindspore::kernel { | |||
| int TransposeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| if (conjugate_) { | |||
| MS_LOG(ERROR) << "Unsupported conjugate transpose."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int TransposeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) { | |||
| op_ = new (std::nothrow) hiai::op::Permute(name_); | |||
| if (op_ == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| op_->set_input_x(*npu_inputs[0]); | |||
| op_->set_attr_order(perm_); | |||
| return RET_OK; | |||
| } | |||
| ge::Operator *mindspore::kernel::TransposeNPUKernel::GetNPUOp() { return this->op_; } | |||
| TransposeNPUKernel::~TransposeNPUKernel() { | |||
| if (op_ != nullptr) { | |||
| delete op_; | |||
| op_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Transpose, NPUKernelCreator<TransposeNPUKernel>) | |||
| // REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, NPUKernelCreator<TransposeNPUKernel>) | |||
| // REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, NPUKernelCreator<TransposeNPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_ | |||
| #include <vector> | |||
| #include "nnacl/transpose.h" | |||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||
| #include "include/graph/op/all_ops.h" | |||
| namespace mindspore::kernel { | |||
| class TransposeNPUKernel : public NPUKernel { | |||
| public: | |||
| TransposeNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| if (primitive->Type() == schema::PrimitiveType_Transpose) { | |||
| auto transpose_parameter = reinterpret_cast<TransposeParameter *>(parameter); | |||
| conjugate_ = transpose_parameter->conjugate_; | |||
| for (int i = 0; i < transpose_parameter->num_axes_; i++) { | |||
| perm_.push_back(transpose_parameter->perm_[i]); | |||
| } | |||
| } else if (primitive->Type() == schema::PrimitiveType_Nchw2Nhwc) { | |||
| perm_ = {0, 2, 3, 1}; | |||
| } else if (primitive->Type() == schema::PrimitiveType_Nhwc2Nchw) { | |||
| perm_ = {0, 3, 1, 2}; | |||
| } | |||
| } | |||
| ~TransposeNPUKernel() override; | |||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) override; | |||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||
| ge::Operator *GetNPUOp() override; | |||
| private: | |||
| hiai::op::Permute *op_ = nullptr; | |||
| std::vector<int64_t> perm_; | |||
| bool conjugate_ = false; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_ | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/npu/unsqueeze_npu.h" | |||
| #include <memory> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_Unsqueeze; | |||
| namespace mindspore::kernel { | |||
| int UnsqueezeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) { | |||
| if (inputs[0]->shape().size() > 3) { | |||
| MS_LOG(WARNING) << "The dimension of output not support bigger than 4."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int UnsqueezeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) { | |||
| op_ = new (std::nothrow) hiai::op::ExpandDims(name_); | |||
| if (op_ == nullptr) { | |||
| MS_LOG(ERROR) << name_ << " op is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| int size = axis_.size(); | |||
| ge::TensorDesc desc(ge::Shape({size}), ge::FORMAT_NCHW, ge::DT_INT32); | |||
| ge::TensorPtr tensor = std::make_shared<hiai::Tensor>(desc); | |||
| tensor->SetData(reinterpret_cast<uint8_t *>(axis_.data()), size * sizeof(int)); | |||
| auto axis = new hiai::op::Const(name_ + "_axis"); | |||
| axis->set_attr_value(tensor); | |||
| op_->set_input_x(*npu_inputs[0]); | |||
| op_->set_input_axis(*axis); | |||
| return RET_OK; | |||
| } | |||
| ge::Operator *mindspore::kernel::UnsqueezeNPUKernel::GetNPUOp() { return this->op_; } | |||
| UnsqueezeNPUKernel::~UnsqueezeNPUKernel() { | |||
| if (op_ != nullptr) { | |||
| delete op_; | |||
| op_ = nullptr; | |||
| } | |||
| } | |||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, NPUKernelCreator<UnsqueezeNPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_ | |||
| #include <vector> | |||
| #include "src/ops/unsqueeze.h" | |||
| #include "src/runtime/kernel/npu/npu_kernel.h" | |||
| #include "include/graph/op/all_ops.h" | |||
| namespace mindspore::kernel { | |||
| class UnsqueezeNPUKernel : public NPUKernel { | |||
| public: | |||
| UnsqueezeNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : NPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| auto unsqueeze = reinterpret_cast<const mindspore::lite::Unsqueeze *>(primitive); | |||
| axis_ = unsqueeze->GetAxis(); | |||
| } | |||
| ~UnsqueezeNPUKernel() override; | |||
| int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter) override; | |||
| int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | |||
| const std::vector<ge::Operator *> &npu_inputs) override; | |||
| ge::Operator *GetNPUOp() override; | |||
| private: | |||
| hiai::op::ExpandDims *op_ = nullptr; | |||
| vector<int> axis_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_ | |||