@@ -20,31 +20,81 @@
 namespace mindspore::lite {
 using kernel::KERNEL_ARCH::kNPU;
-enum InsertState { InsertNone, PreInsert, PostInsert };
+enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };
 std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
                                                                schema::PrimitiveType_Eltwise,
                                                                schema::PrimitiveType_Activation};
+// The goal of this pass is to minimize the number of subgraphs generated by inserting nchw2nhwc or nhwc2nchw
+// before or after an operator (e.g. concat, add, etc.) and then relying on the fusion pass.
+// If the transposes already present cover more than half of an op's inputs and outputs, we wrap the remaining
+// inputs and outputs with transposes as well, hoping the fusion pass removes them all. Otherwise we insert nothing.
+//
+// Typically concat accepts an output of a nchw2nhwc op; we wrap its other inputs with nh2nc and nc2nh pairs so
+// that all inputs to concat share one format, and the fusion pass can then fold every nchw2nhwc op.
+// e.g.
+// original:     (conv->nchw2nhwc, add(format nhwc)) -> concat -> (nhwc2nchw->conv)
+// current pass: (conv->nchw2nhwc, add->nhwc2nchw->nchw2nhwc) -> concat -> (nhwc2nchw->conv)
+// fusion pass:  (conv, add->nhwc2nchw) -> concat -> conv
+// The original graph needs 2 CPU subgraphs; after the two passes, only 1 CPU subgraph remains.
+//
+// Note:
+// Such ops require all of their inputs to share one format, which could be nchw, nhwc or any other format.
+// Their inputs and outputs may not be 4-dimensional, or their formats may already match, so we insert no
+// nc2nh or nh2nc when the op's in kernels and out kernels contain no nc2nh or nh2nc.
+// This pass should run after npu_transform_pass, which inserts transposes for nchw-input-limited ops like conv2d.
 int GetInsertState(kernel::LiteKernel *kernel) {
+  // filter out irrelevant kernels
   if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
     return InsertNone;
   }
-  auto pre_flag = std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(),
-                              [](const kernel::LiteKernel *kernel) { return NPUPassUtils::IsNchw2Nhwc(kernel); });
-  auto post_flag = std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
-                               [](const kernel::LiteKernel *kernel) { return NPUPassUtils::IsNhwc2Nchw(kernel); });
-  if (pre_flag && !post_flag) {
-    return PostInsert;
+
+  // the current kernel is a target kernel;
+  // use out_kernels() to count how many edges lead out of the current kernel
+  size_t in_out_tensor_num = kernel->in_tensors().size() + kernel->out_kernels().size();
+  size_t transpose_input_num = 0;
+  size_t transpose_output_num = 0;
+  bool need_pre_insert = false;
+  bool need_post_insert = false;
+  // count the input tensors that come from nc2nh and the output tensors that go to nh2nc
+  for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
+    auto in_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
+    if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
+      transpose_input_num++;
+    } else {
+      need_pre_insert = true;
+    }
+  }
+  for (const auto out_kernel : kernel->out_kernels()) {
+    if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
+      transpose_output_num++;
+    } else {
+      need_post_insert = true;
+    }
+  }
+
+  // insert nothing if the number of transpose tensors is at most half of the total inputs and outputs,
+  // or if every input and output is already a transpose tensor; the fusion pass will handle the latter case.
+  size_t transpose_tensor_num = transpose_input_num + transpose_output_num;
+  if (transpose_tensor_num <= in_out_tensor_num / 2 || transpose_tensor_num == in_out_tensor_num) {
+    return InsertNone;
   }
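+  // Illustrative walk-through (editor's example, not from the original change): a concat with three input
+  // tensors and one out kernel gives in_out_tensor_num = 4. If all three inputs come from nchw2nhwc but the
+  // output does not feed a nhwc2nchw, transpose_tensor_num = 3 > 4 / 2 and 3 != 4, so the checks below see
+  // only need_post_insert set and return PostInsert.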
-  if (!pre_flag && post_flag) {
+  if (need_pre_insert && !need_post_insert) {
     return PreInsert;
   }
+  if (need_pre_insert && need_post_insert) {
+    return BothInsert;
+  }
+  if (!need_pre_insert && need_post_insert) {
+    return PostInsert;
+  }
   return InsertNone;
 }
 int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
-                                       std::vector<kernel::LiteKernel *> *trans_kernels) {
+                                       size_t post_input_index, std::vector<kernel::LiteKernel *> *trans_kernels) {
   // Kernel and post_kernel can't be nullptr at the same time.
   std::string kernel_name;
   Tensor *in_tensor = nullptr;
@@ -54,7 +104,7 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteK
   if (post_kernel != nullptr) {
     out_kernels.push_back(post_kernel);
     kernel_name = post_kernel->name() + "_pre";
-    in_tensor = post_kernel->in_tensors()[0];
+    in_tensor = post_kernel->in_tensors().at(post_input_index);
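+    // e.g. (illustrative, editor's note): if the transpose pair feeds the second input of a concat,
+    // post_input_index is 1; hard-coding in_tensors()[0] here would pick up the wrong operand.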
   }
   std::vector<kernel::LiteKernel *> in_kernels;
   // If kernel equals nullptr, post_kernel is the input of the whole graph.
@@ -99,87 +149,134 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteK
   }
   if (post_kernel != nullptr) {
     NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel, nc2nh_kernel, post_kernel);
+  } else {
+    // a null post_kernel means this tensor is a graph output; keep the graph output tensor name unchanged
+    auto graph_output_name = in_tensor->tensor_name();
+    in_tensor->set_tensor_name(graph_output_name + "_before_" + name_);
+    nc2nh_tensor->set_tensor_name(graph_output_name);
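+    // e.g. (illustrative, editor's note): a graph output named "out" is renamed to "out_before_" + name_ on
+    // the producer side, while the inserted nc2nh transpose writes the tensor still named "out", so consumers
+    // of the graph output are unaffected.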
   }
   return RET_OK;
 }
+
+int NPUInsertTransformPass::InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index,
+                                                 kernel::LiteKernel *pre_kernel,
+                                                 std::vector<kernel::LiteKernel *> *trans_kernels) {
+  // insert transpose nodes before the target op
+  return InsertNode(pre_kernel, kernel, in_tensor_index, trans_kernels);
+}
+
+int NPUInsertTransformPass::InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
+                                                  size_t post_in_tensor_index,
+                                                  std::vector<kernel::LiteKernel *> *trans_kernels) {
+  // insert transpose nodes after the target op
+  return InsertNode(kernel, post_kernel, post_in_tensor_index, trans_kernels);
+}
 int NPUInsertTransformPass::InsertPreNodes(kernel::LiteKernel *kernel,
                                            std::vector<kernel::LiteKernel *> *trans_kernels) {
-  if (kernel->in_kernels().size() != kernel->in_tensors().size()) {
-    MS_LOG(DEBUG) << "The input tensors of kernel may be the input of whole graph or const tensor.";
-    return RET_OK;
-  }
-  if (kernel->in_kernels().empty()) {
-    auto ret = InsertNode(nullptr, kernel, trans_kernels);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
-    }
-  }
-  for (auto in_kernel : kernel->in_kernels()) {
-    if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
+  int ret = RET_OK;
+  for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
+    auto pre_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
+    if (NPUPassUtils::IsNchw2Nhwc(pre_kernel)) {
       continue;
     }
-    auto ret = InsertNode(in_kernel, kernel, trans_kernels);
+    // if this tensor is an input of the graph, pre_kernel is nullptr.
+    ret = InsertForInputTensor(kernel, i, pre_kernel, trans_kernels);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  return RET_OK;
+  return ret;
 }
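+// Design note (editor's gloss): iterating in_tensors() instead of in_kernels() also covers inputs that no
+// kernel produces (graph inputs and const tensors); for those, KernelInputFromKernel() yields nullptr and
+// InsertNode() treats the tensor as an input of the whole graph.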
 int NPUInsertTransformPass::InsertPostNodes(kernel::LiteKernel *kernel,
                                             std::vector<kernel::LiteKernel *> *trans_kernels) {
-  if (kernel->out_kernels().empty()) {
-    auto ret = InsertNode(kernel, nullptr, trans_kernels);
+  int ret = RET_OK;
+
+  for (const auto post_kernel : kernel->out_kernels()) {
+    if (NPUPassUtils::IsNhwc2Nchw(post_kernel)) {
+      continue;
+    }
+    auto post_kernel_in_tensors = post_kernel->in_tensors();
+    // kernel's out tensor is one of post_kernel's input tensors
+    auto it = std::find(post_kernel_in_tensors.begin(), post_kernel_in_tensors.end(), kernel->out_tensors().at(0));
+    if (it == post_kernel_in_tensors.end()) {
+      return RET_ERROR;
+    }
+    size_t input_index = it - post_kernel_in_tensors.begin();
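+    // e.g. (illustrative, editor's note): if post_kernel is concat(a, x, b) and x is this kernel's output,
+    // input_index is 1, so the transpose pair is spliced into concat's second input slot.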
+    ret = InsertForOutputTensor(kernel, post_kernel, input_index, trans_kernels);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  for (auto out_kernel : kernel->out_kernels()) {
-    if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
-      continue;
-    }
-    auto ret = InsertNode(kernel, out_kernel, trans_kernels);
+  if (kernel->out_tensors().size() > kernel->out_kernels().size()) {
+    // kernel's output is also a graph output
+    ret = InsertForOutputTensor(kernel, nullptr, 0, trans_kernels);
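+    // Note (editor's): judging from the hunk above, post_input_index is only read when post_kernel is
+    // non-null, so the 0 passed here for the graph-output case is effectively a placeholder.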
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  return RET_OK;
+  return ret;
 }
 int NPUInsertTransformPass::Run() {
+  std::vector<kernel::LiteKernel *> insert_kernels;
   for (size_t i = 0; i < all_kernels_->size(); i++) {
     auto kernel = (*all_kernels_)[i];
     if (kernel->desc().arch != kNPU) {
       continue;
     }
     auto insert_state = GetInsertState(kernel);
+    insert_kernels.clear();
     // Insert transposes according to insert_state, then advance the loop index i past the newly inserted
     // kernels so the scan resumes at the next kernel of the origin vector.
-    if (insert_state == PreInsert) {
-      std::vector<kernel::LiteKernel *> pre_kernels;
-      auto ret = InsertPreNodes(kernel, &pre_kernels);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-        return RET_ERROR;
-      }
-      all_kernels_->insert(all_kernels_->begin() + i, pre_kernels.begin(), pre_kernels.end());
-      i += pre_kernels.size();
-    }
+    switch (insert_state) {
+      case PreInsert: {
+        auto ret = InsertPreNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
+                        << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
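+      // e.g. (illustrative, editor's note): inserting two pre-transposes at position i shifts the current
+      // kernel to i + 2, so i += insert_kernels.size() re-aligns i with it before the loop's i++ moves on.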
-    if (insert_state == PostInsert) {
-      std::vector<kernel::LiteKernel *> post_kernels;
-      auto ret = InsertPostNodes(kernel, &post_kernels);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-        return RET_ERROR;
-      }
-      all_kernels_->insert(all_kernels_->begin() + i + 1, post_kernels.begin(), post_kernels.end());
-      i += post_kernels.size();
-    }
+      case PostInsert: {
+        auto ret = InsertPostNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
+      case BothInsert: {
+        auto ret = InsertPreNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
+                        << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+
+        insert_kernels.clear();
+        ret = InsertPostNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
+      default:
+        MS_LOG(DEBUG) << "Insert Nothing on kernel " << kernel->name();
+    }
   }
   return RET_OK;