diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 93a7f0ef9c..62d9536892 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -356,18 +356,16 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrname() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) { + if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend || + prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { continue; } if (prim->name() == "make_tuple") { continue; } + RemoveIfMakeTuple(cnode); - if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { - continue; - } auto node = std::make_unique(); if (node == nullptr) { MS_LOG(ERROR) << "object failed to be constructed"; diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 33b3313193..d7f79b6952 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -137,10 +137,11 @@ static const std::vector int8OpList = {schema::PrimitiveT schema::PrimitiveType_L2NormalizeFusion}; static const std::vector needInsertOpList = { - schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, - schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion, - schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion, schema::PrimitiveType_Crop, - schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad}; + schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, + schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion, + schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion, + schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum, + schema::PrimitiveType_ActivationGrad}; static const std::unordered_map nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 7cb738c0e5..87dde56380 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -15,12 +15,28 @@ */ #include "tools/optimizer/graph/redundant_op_remove_pass.h" -#include "mindspore/lite/include/errorcode.h" +#include +#include +#include "include/errorcode.h" +#include "ops/make_tuple.h" namespace mindspore::opt { namespace { constexpr size_t InputDoubleNum = 2; constexpr size_t InputTripleNum = 3; +void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector *inputs) { + MS_ASSERT(anf_node != nullptr); + MS_ASSERT(inputs != nullptr); + auto cnode = anf_node->cast(); + if (cnode == nullptr) { + return; + } + for (size_t i = 1; i < cnode->size(); ++i) { + if (cnode->input(i)->isa()) { + inputs->push_back(cnode->input(i)); + } + } +} } // namespace int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { if (!utils::isa(anf_node)) { @@ -58,6 +74,36 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph return RET_OK; } +int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) { + if (!utils::isa(anf_node)) { + MS_LOG(DEBUG) << "anf node is node a cnode."; + return lite::RET_NO_CHANGE; + } + auto cnode = anf_node->cast(); + auto inputs = cnode->inputs(); + std::vector new_inputs; + for (size_t i = 1; i < inputs.size(); ++i) { + if (!inputs[i]->isa()) { + continue; + } + if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) { + FetchCNodeFromMakeTuple(inputs[i], &new_inputs); + continue; + } + new_inputs.push_back(inputs[i]); + } + for (auto &node : new_inputs) { + func_graph->get_return()->add_input(node); + } + auto value = std::make_shared(); + bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value)); + if (!replace_succ) { + MS_LOG(ERROR) << "replace redundant op failed."; + return lite::RET_ERROR; + } + return RET_OK; +} + int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { if (!utils::isa(anf_node)) { MS_LOG(DEBUG) << "anf node is node a cnode."; @@ -111,7 +157,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { status = ReplaceOp(node, manager); } if (CheckPrimitiveType(node, prim::kPrimUpdateState)) { - status = ReplaceOp(node, manager); + status = ReplaceUpdateStateOp(func_graph, node); } if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { status = ReplaceTupleGetItem(node, manager); diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h index 1bef786eaa..ad7532aa70 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h @@ -29,6 +29,7 @@ class RemoveRedundantOpPass : public Pass { RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {} ~RemoveRedundantOpPass() override = default; int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); + int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node); int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); bool Run(const FuncGraphPtr &graph) override;