@@ -21,20 +21,22 @@
 namespace mindspore::lite {
 bool CheckFusion(kernel::LiteKernel *kernel) {
   auto pre_flag =
-    std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *kernel) {
-      return kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && kernel->out_kernels().size() == 1;
+    std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) {
+      return in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && in_kernel->out_kernels().size() == 1;
     });
   if (!pre_flag) {
     return false;
   }
-  auto post_flag =
-    std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(), [](const kernel::LiteKernel *kernel) {
-      return kernel->Type() == schema::PrimitiveType_Nhwc2Nchw && kernel->in_kernels().size() == 1;
-    });
+  auto post_flag = std::all_of(
+    kernel->out_kernels().begin(), kernel->out_kernels().end(),
+    [](const kernel::LiteKernel *out_kernel) { return out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw; });
   return post_flag;
 }
 bool CheckFormatFusion(kernel::LiteKernel *kernel) {
+  if (kernel->out_kernels().empty()) {
+    return false;
+  }
   if (kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
     return std::all_of(
       kernel->out_kernels().begin(), kernel->out_kernels().end(),
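For reference, a minimal self-contained sketch of the invariant the reworked CheckFusion enforces, written against a toy Node type rather than the LiteKernel API (all names here are illustrative): a kernel is fusible only when every producer is an Nchw2Nhwc transpose feeding exactly one consumer and every consumer is an Nhwc2Nchw transpose.

#include <algorithm>
#include <string>
#include <vector>

struct Node {
  std::string type;
  std::vector<Node *> in_nodes;
  std::vector<Node *> out_nodes;
};

bool CheckFusible(const Node &node) {
  // Every producer must be an NCHW->NHWC transpose that feeds only this node.
  bool pre_ok = std::all_of(node.in_nodes.begin(), node.in_nodes.end(), [](const Node *in) {
    return in->type == "Nchw2Nhwc" && in->out_nodes.size() == 1;
  });
  if (!pre_ok) {
    return false;
  }
  // Every consumer must be an NHWC->NCHW transpose; its fan-in is no longer restricted.
  return std::all_of(node.out_nodes.begin(), node.out_nodes.end(),
                     [](const Node *out) { return out->type == "Nhwc2Nchw"; });
}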
@@ -159,38 +161,26 @@ int TransFormAxis(int axis) {
   }
 }
-int NPUFusionPass::AddFusion(kernel::LiteKernel *kernel) {
-  if (!CheckFusion(kernel)) {
-    return RET_OK;
-  }
+void NPUFusionPass::UpdateKernel(kernel::LiteKernel *kernel) {
   UpdatePreTensors(kernel);
   UpdatePostTensors(kernel);
   UpdatePreKernels(kernel);
   UpdatePostKernels(kernel);
+}
+int NPUFusionPass::CommonFusion(kernel::LiteKernel *kernel) {
+  UpdateKernel(kernel);
   return RET_OK;
 }
 int NPUFusionPass::ConcatFusion(kernel::LiteKernel *kernel) {
-  if (!CheckFusion(kernel)) {
-    return RET_OK;
-  }
-  UpdatePreTensors(kernel);
-  UpdatePostTensors(kernel);
-  UpdatePreKernels(kernel);
-  UpdatePostKernels(kernel);
+  UpdateKernel(kernel);
   auto concat_param = reinterpret_cast<ConcatParameter *>(kernel->op_parameter());
   concat_param->axis_ = TransFormAxis(concat_param->axis_);
   return RET_OK;
 }
 int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
-  if (kernel->out_kernels().empty()) {
-    return RET_OK;
-  }
-  if (!CheckFormatFusion(kernel)) {
-    return RET_OK;
-  }
   auto pre_kernel = kernel->in_kernels()[0];
   auto in_tensor = kernel->in_tensors()[0];
   auto out_tensor = kernel->out_tensors()[0];
@@ -237,17 +227,28 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
 }
 int NPUFusionPass::Run() {
-  for (auto kernel : *kernels) {
+  for (size_t i = 0; i < kernels->size(); i++) {
+    auto kernel = (*kernels)[i];
+    if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc || kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
+      if (CheckFormatFusion(kernel)) {
+        i--;
+        FormatFusion(kernel);
+      }
+      continue;
+    }
+    if (!CheckFusion(kernel)) {
+      continue;
+    }
     switch (kernel->Type()) {
       case schema::PrimitiveType_Concat:
+        i -= kernel->in_kernels().size();
         ConcatFusion(kernel);
         continue;
       case schema::PrimitiveType_Add:
       case schema::PrimitiveType_Activation:
-        AddFusion(kernel);
-        continue;
-      case schema::PrimitiveType_Nchw2Nhwc:
-        FormatFusion(kernel);
+      case schema::PrimitiveType_Eltwise:
+        i -= kernel->in_kernels().size();
+        CommonFusion(kernel);
         continue;
       default:
         continue;
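The index bookkeeping in the reworked Run() loop exists because fusion erases kernels from the same vector that is being iterated. A minimal sketch with plain ints (hypothetical stand-ins, not the pass itself) shows why the index must step back by the number of elements removed in front of the current position:

#include <vector>

// A negative entry stands in for a kernel that absorbs (fuses away) its producer.
void FuseInPlace(std::vector<int> *items) {
  for (size_t i = 0; i < items->size(); i++) {
    if ((*items)[i] < 0 && i > 0) {
      items->erase(items->begin() + (i - 1));  // the fused-away producer leaves the list
      i -= 1;  // mirrors `i--` and `i -= kernel->in_kernels().size();` so the next element is not skipped
    }
  }
}

Without the decrement, every erase would shift the survivors left and the loop would jump over the kernel that now occupies the current slot.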
@@ -33,11 +33,12 @@ class NPUFusionPass : public NPUBasePass {
   int Run() override;
  protected:
-  void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel);
   void UpdatePreKernels(kernel::LiteKernel *kernel);
   void UpdatePostKernels(kernel::LiteKernel *kernel);
+  void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel);
+  void UpdateKernel(kernel::LiteKernel *kernel);
+  int CommonFusion(kernel::LiteKernel *kernel);
   int ConcatFusion(kernel::LiteKernel *kernel);
-  int AddFusion(kernel::LiteKernel *kernel);
   int FormatFusion(kernel::LiteKernel *kernel);
  private:
@@ -21,7 +21,9 @@ namespace mindspore::lite {
 using kernel::KERNEL_ARCH::kNPU;
 enum InsertState { InsertNone, PreInsert, PostInsert };
-std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add};
+std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
+                                                               schema::PrimitiveType_Eltwise,
+                                                               schema::PrimitiveType_Activation};
 int GetInsertState(kernel::LiteKernel *kernel) {
   if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
@@ -42,16 +44,18 @@ int GetInsertState(kernel::LiteKernel *kernel) {
   return InsertNone;
 }
-int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
-                                          std::vector<kernel::LiteKernel *> *all_kernels,
+int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
+                                          std::vector<kernel::LiteKernel *> *trans_kernels,
                                           std::vector<Tensor *> *all_tensors) {
-  for (auto kernel : cur_kernel->in_kernels()) {
-    if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
+  for (auto in_kernel : kernel->in_kernels()) {
+    if (in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
       continue;
     }
-    auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
+    auto nhwc_shape = in_kernel->out_tensors()[0]->shape();
     std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
-    auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
+    auto nh2nc_tensor =
+      new Tensor(in_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
     std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
     all_tensors->push_back(nh2nc_tensors[0]);
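The recurring initializer {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]} is the NHWC-to-NCHW permutation used for every inserted transpose tensor. A tiny stand-alone helper (illustrative only, not part of the patch) makes the mapping explicit:

#include <vector>

// {N, H, W, C} -> {N, C, H, W}: the channel axis moves to position 1.
std::vector<int> Nhwc2NchwShape(const std::vector<int> &nhwc) {
  return {nhwc[0], nhwc[3], nhwc[1], nhwc[2]};
}

// Example: Nhwc2NchwShape({1, 224, 224, 3}) yields {1, 3, 224, 224}.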
@@ -59,34 +63,36 @@ int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::L
     std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
     all_tensors->push_back(nc2nh_tensors[0]);
-    auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++);
-    auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
-    all_kernels->push_back(nh2nc_kernel);
+    auto nh2nc_name = in_kernel->name() + "_nh2nc_" + std::to_string(total++);
+    auto *nh2nc_kernel =
+      NPUPassUtils::CreateNhwc2NchwKernel(in_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
+    trans_kernels->push_back(nh2nc_kernel);
     insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
-    auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++);
+    auto nc2nh_name = in_kernel->name() + "_nc2nh_" + std::to_string(total++);
     auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
-    all_kernels->push_back(nc2nh_kernel);
+    trans_kernels->push_back(nc2nh_kernel);
     insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
-    NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors);
-    NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {cur_kernel}, nh2nc_tensors, nc2nh_tensors);
-    NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, cur_kernel);
-    NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, cur_kernel);
+    NPUPassUtils::UpdateKernel(nh2nc_kernel, {in_kernel}, {nc2nh_kernel}, in_kernel->out_tensors(), nh2nc_tensors);
+    NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {kernel}, nh2nc_tensors, nc2nh_tensors);
+    NPUPassUtils::UpdateNH2NCTransNodePreKernel(in_kernel, nh2nc_kernel, kernel);
+    NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(in_kernel, nc2nh_kernel, kernel);
   }
   return RET_OK;
 }
-int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
-                                           std::vector<kernel::LiteKernel *> *all_kernels,
+int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
+                                           std::vector<kernel::LiteKernel *> *trans_kernels,
                                            std::vector<Tensor *> *all_tensors) {
-  for (auto out_kernel : cur_kernel->out_kernels()) {
+  for (auto out_kernel : kernel->out_kernels()) {
     if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
       continue;
     }
-    auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
+    auto nhwc_shape = kernel->out_tensors()[0]->shape();
     std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
-    auto nh2nc_tensor =
-      new Tensor(cur_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
+    auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
     std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
     all_tensors->push_back(nh2nc_tensors[0]);
@@ -94,19 +100,20 @@ int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::
     std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
     all_tensors->push_back(nc2nh_tensors[0]);
-    auto nh2nc_name = cur_kernel->name() + "_nh2nc_" + std::to_string(total++);
-    auto *nh2nc_kernel =
-      NPUPassUtils::CreateNhwc2NchwKernel(cur_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
-    all_kernels->push_back(nh2nc_kernel);
+    auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++);
+    auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
+    trans_kernels->push_back(nh2nc_kernel);
     insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
-    auto nc2nh_name = cur_kernel->name() + "_nc2nh_" + std::to_string(total++);
+    auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++);
     auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
-    all_kernels->push_back(nc2nh_kernel);
+    trans_kernels->push_back(nc2nh_kernel);
     insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
-    NPUPassUtils::UpdateKernel(nh2nc_kernel, {cur_kernel}, {nc2nh_kernel}, cur_kernel->out_tensors(), nh2nc_tensors);
+    NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors);
     NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {out_kernel}, nh2nc_tensors, nc2nh_tensors);
-    NPUPassUtils::UpdateNH2NCTransNodePreKernel(cur_kernel, nh2nc_kernel, out_kernel);
-    NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(cur_kernel, nc2nh_kernel, out_kernel);
+    NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, out_kernel);
+    NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, out_kernel);
   }
   return RET_OK;
 }
@@ -41,11 +41,11 @@ class NPUInsertTransformPass : public NPUBasePass {
   int Run() override;
  private:
-  int InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
-                    std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
+  int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
+                    std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
-  int InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
-                     std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
+  int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
+                     std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
  private:
   int total = 0;
@@ -100,25 +100,25 @@ void NPUPassUtils::UpdateKernel(kernel::LiteKernel *kernel, const std::vector<ke
   kernel->set_out_kernels(out_kernels);
 }
-void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                                 kernel::LiteKernel *after_kernel) {
+void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel,
+                                                 kernel::LiteKernel *kernel) {
   std::vector<kernel::LiteKernel *> out_kernels;
-  for (auto out_kernel : kernel->out_kernels()) {
-    if (out_kernel == after_kernel) {
+  for (auto out_kernel : pre_kernel->out_kernels()) {
+    if (out_kernel == kernel) {
       out_kernels.push_back(trans_kernel);
     } else {
       out_kernels.push_back(out_kernel);
     }
   }
-  UpdateKernel(kernel, kernel->in_kernels(), out_kernels, kernel->in_tensors(), kernel->out_tensors());
+  pre_kernel->set_out_kernels(out_kernels);
 }
 void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                                 kernel::LiteKernel *next_kernel) {
+                                                 kernel::LiteKernel *post_kernel) {
   std::vector<kernel::LiteKernel *> cur_out_kernels;
   for (auto out_kernel : kernel->out_kernels()) {
-    if (out_kernel == next_kernel) {
+    if (out_kernel == post_kernel) {
       cur_out_kernels.push_back(trans_kernel);
     } else {
       cur_out_kernels.push_back(out_kernel);
@@ -130,45 +130,47 @@ void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, ker
   std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
   kernel_out_tensor->set_format(schema::Format_NCHW);
   kernel_out_tensor->set_shape(nchw_shape);
-  UpdateKernel(kernel, kernel->in_kernels(), cur_out_kernels, kernel->in_tensors(), {kernel_out_tensor});
+  kernel->set_out_kernels(cur_out_kernels);
+  kernel->set_out_tensors({kernel_out_tensor});
 }
 void NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                                   kernel::LiteKernel *before_kernel) {
+                                                   kernel::LiteKernel *pre_kernel) {
   std::vector<lite::Tensor *> cur_kernel_in_tensors = {trans_kernel->out_tensors()[0]};
   for (int i = 1; i < kernel->in_tensors().size(); i++) {
     cur_kernel_in_tensors.push_back(kernel->in_tensors()[i]);
   }
   std::vector<kernel::LiteKernel *> cur_in_kernels = {trans_kernel};
-  for (int i = 0; i < kernel->in_kernels().size(); i++) {
+  for (int i = 1; i < kernel->in_kernels().size(); i++) {
     auto in_kernel = kernel->in_kernels()[i];
     if (in_kernel != kernel) {
       cur_in_kernels.push_back(in_kernel);
     }
   }
-  UpdateKernel(kernel, cur_in_kernels, kernel->out_kernels(), cur_kernel_in_tensors, kernel->out_tensors());
+  kernel->set_in_kernels(cur_in_kernels);
+  kernel->set_in_tensors({cur_kernel_in_tensors});
 }
 void NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                                   kernel::LiteKernel *next_kernel) {
-  std::vector<Tensor *> next_in_tensors;
-  for (auto next_in_tensor : next_kernel->in_tensors()) {
-    if (next_in_tensor != kernel->out_tensors()[0]) {
-      next_in_tensors.push_back(next_in_tensor);
+                                                   kernel::LiteKernel *post_kernel) {
+  std::vector<Tensor *> post_in_tensors;
+  for (auto post_in_tensor : post_kernel->in_tensors()) {
+    if (post_in_tensor != kernel->out_tensors()[0]) {
+      post_in_tensors.push_back(post_in_tensor);
     } else {
-      next_in_tensors.push_back(trans_kernel->out_tensors()[0]);
+      post_in_tensors.push_back(trans_kernel->out_tensors()[0]);
     }
   }
-  next_kernel->set_in_tensors(next_in_tensors);
-  std::vector<kernel::LiteKernel *> next_in_kernels;
-  for (auto in_kernel : next_kernel->in_kernels()) {
+  post_kernel->set_in_tensors(post_in_tensors);
+  std::vector<kernel::LiteKernel *> post_in_kernels;
+  for (auto in_kernel : post_kernel->in_kernels()) {
     if (in_kernel == kernel) {
-      next_in_kernels.push_back(trans_kernel);
+      post_in_kernels.push_back(trans_kernel);
     } else {
-      next_in_kernels.push_back(in_kernel);
+      post_in_kernels.push_back(in_kernel);
     }
   }
-  NPUPassUtils::UpdateKernel(next_kernel, next_in_kernels, next_kernel->out_kernels(), next_in_tensors,
-                             next_kernel->out_tensors());
+  post_kernel->set_in_kernels(post_in_kernels);
+  post_kernel->set_in_tensors({post_in_tensors});
 }
 }  // namespace mindspore::lite
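The rewiring performed by the updated UpdateNC2NHTransNodeAfterKernel boils down to one edge swap, sketched below with a toy Node type (not the LiteKernel/NPUPassUtils API; names are illustrative): every reference the downstream kernel holds to the original producer is replaced by the inserted transpose, and all of its other inputs are left untouched.

#include <vector>

struct Node {
  std::vector<Node *> in_nodes;
};

// Route the consumer's dependence on the producer through a freshly inserted transpose node.
void RouteThroughTranspose(Node *producer, Node *trans, Node *consumer) {
  std::vector<Node *> new_inputs;
  for (Node *in : consumer->in_nodes) {
    new_inputs.push_back(in == producer ? trans : in);  // swap only the edge that pointed at `producer`
  }
  consumer->in_nodes = new_inputs;
  trans->in_nodes = {producer};  // the transpose now consumes the producer's output
}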
@@ -35,17 +35,17 @@ class NPUPassUtils {
                            const std::vector<kernel::LiteKernel *> &out_kernels,
                            const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors);
-  static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                            kernel::LiteKernel *after_kernel);
+  static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel,
+                                            kernel::LiteKernel *kernel);
   static void UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                            kernel::LiteKernel *next_kernel);
+                                            kernel::LiteKernel *post_kernel);
   static void UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                              kernel::LiteKernel *before_kernel);
+                                              kernel::LiteKernel *pre_kernel);
   static void UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
-                                              kernel::LiteKernel *next_kernel);
+                                              kernel::LiteKernel *post_kernel);
  private:
   static PrimitiveC *CreateNchw2NhwcPrimitive();
@@ -19,51 +19,53 @@
 #include "src/runtime/agent/npu/npu_manager.h"
 #include "src/runtime/agent/npu/optimizer/npu_pass_utils.h"
 namespace mindspore::lite {
+using kernel::KERNEL_ARCH::kCPU;
 using kernel::KERNEL_ARCH::kNPU;
 int NPUTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
-                                    std::vector<kernel::LiteKernel *> *all_kernels,
+                                    std::vector<kernel::LiteKernel *> *trans_kernels,
                                     std::vector<Tensor *> *all_tensors) {
   bool is_input_kernel = kernel->in_kernels().empty();
   if (is_input_kernel || kernel->in_kernels()[0]->desc().arch != kNPU ||
       npu_trans_nodes.find(kernel->in_kernels()[0]->Type()) == npu_trans_nodes.end()) {
-    kernel::LiteKernel *before_kernel = nullptr;
+    kernel::LiteKernel *pre_kernel = nullptr;
     if (!is_input_kernel) {
-      before_kernel = kernel->in_kernels()[0];
+      pre_kernel = kernel->in_kernels()[0];
     }
-    // Create pre transform kernel out tensors.
+    // Create pre transform kernel's out tensor.
     auto nhwc_shape = kernel->in_tensors()[0]->shape();
     std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
     auto tensor = new Tensor(kernel->in_tensors()[0]->data_type(), nchw_shape, schema::Format_NCHW, Tensor::VAR);
     std::vector<Tensor *> pre_trans_out_tensors = {tensor};
     all_tensors->push_back(pre_trans_out_tensors[0]);
-    // Replace the output tensor of the previous node
+    // Create pre transform kernel: Nhwc2Nchw
     auto name = kernel->name() + "_pre_trans" + "_Nhwc2Nchw_" + std::to_string(total++);
-    auto *pre_trans_kernel =
+    auto *trans_kernel =
       NPUPassUtils::CreateNhwc2NchwKernel({kernel->in_tensors()[0]}, pre_trans_out_tensors, context, name);
-    // Insert Nhwc2Nchw into the front of the current queue
-    all_kernels->push_back(pre_trans_kernel);
-    insert_primitive_.push_back(pre_trans_kernel->GetPrimitive());
-    // Replace the output kernel of the previous node
+    trans_kernels->push_back(trans_kernel);
+    insert_primitive_.push_back(trans_kernel->GetPrimitive());
+    // Set in_kernels, out_kernels, in_tensors, out_tensors for the transform kernel
     std::vector<kernel::LiteKernel *> pre_trans_in_kernel;
     if (is_input_kernel) {
       pre_trans_in_kernel = {};
     } else {
-      pre_trans_in_kernel = {before_kernel};
+      pre_trans_in_kernel = {pre_kernel};
    }
-    NPUPassUtils::UpdateKernel(pre_trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]},
+    NPUPassUtils::UpdateKernel(trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]},
                                pre_trans_out_tensors);
-    if (before_kernel != nullptr) {
-      NPUPassUtils::UpdateNH2NCTransNodePreKernel(before_kernel, pre_trans_kernel, kernel);
+    if (pre_kernel != nullptr) {
+      NPUPassUtils::UpdateNH2NCTransNodePreKernel(pre_kernel, trans_kernel, kernel);
     }
-    NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, pre_trans_kernel, before_kernel);
+    NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, trans_kernel, pre_kernel);
   }
   return RET_OK;
 }
 int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
-                                     std::vector<kernel::LiteKernel *> *all_kernels,
+                                     std::vector<kernel::LiteKernel *> *trans_kernels,
                                      std::vector<Tensor *> *all_tensors) {
   // Model output does not insert operator
   if (kernel->out_kernels().empty()) {
@@ -71,27 +73,30 @@ int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKe
   }
   // Single output multiple references
   for (int i = 0; i < kernel->out_kernels().size(); i++) {
-    auto next_kernel = kernel->out_kernels().at(i);
-    if (next_kernel->desc().arch == kNPU && npu_trans_nodes.find(next_kernel->Type()) != npu_trans_nodes.end()) {
+    auto post_kernel = kernel->out_kernels().at(i);
+    if (post_kernel->desc().arch == kNPU && npu_trans_nodes.find(post_kernel->Type()) != npu_trans_nodes.end()) {
       continue;
     }
-    // Change format the output of the current kernel nhwc->nchw
+    // Create post transform kernel's out tensor.
     auto tensor = new Tensor(kernel->out_tensors()[0]->data_type(), kernel->out_tensors()[0]->shape(),
                              schema::Format_NHWC, Tensor::VAR);
     std::vector<Tensor *> post_trans_out_tensors = {tensor};
     all_tensors->push_back(post_trans_out_tensors[0]);
-    // Use the output tensor of the current node as the input tensor of the post-conversion operator
+    // Create post transform kernel: Nchw2Nhwc
    auto name = kernel->name() + "_post_trans" + "_Nchw2Nhwc" + std::to_string(total++);
    auto *post_trans_kernel =
      NPUPassUtils::CreateNchw2NhwcKernel(kernel->out_tensors(), post_trans_out_tensors, context, name);
-    // Replace the input tensor of the next node
-    NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {next_kernel}, kernel->out_tensors(),
+    // Set in_kernels, out_kernels, in_tensors, out_tensors for the transform kernel
+    NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {post_kernel}, kernel->out_tensors(),
                                post_trans_out_tensors);
     insert_primitive_.push_back(post_trans_kernel->GetPrimitive());
-    // Directly insert in the back, will not affect the topological sort
-    all_kernels->push_back(post_trans_kernel);
-    NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, next_kernel);
-    NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, next_kernel);
+    trans_kernels->push_back(post_trans_kernel);
+    NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, post_kernel);
+    NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, post_kernel);
   }
   return RET_OK;
 }
@@ -43,10 +43,10 @@ class NPUTransformPass : public NPUBasePass {
  private:
   int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
-                    std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
+                    std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
   int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
-                     std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
+                     std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
  private:
   int total = 0;