| @@ -225,7 +225,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | |||
| @@ -59,41 +59,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | |||
| } | |||
| } | |||
| void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||
| bool hasDepend = false; | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.clear(); | |||
| inputs.emplace_back(cnode->input(0)); | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| AnfNodePtr inputNode = cnode->input(i); | |||
| if (!inputNode->isa<CNode>()) { | |||
| inputs.emplace_back(cnode->input(i)); | |||
| continue; | |||
| } | |||
| auto dependNode = utils::cast<CNodePtr>(inputNode); | |||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | |||
| IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | |||
| hasDepend = true; | |||
| bool maskOut = (dependNode->inputs().size() == 3); | |||
| for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | |||
| AnfNodePtr dependInputNode = dependNode->input(j); | |||
| if (dependInputNode->isa<CNode>()) { | |||
| inputs.emplace_back(dependInputNode); | |||
| if (maskOut) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| inputs.emplace_back(cnode->input(i)); | |||
| } | |||
| } | |||
| if (hasDepend) { | |||
| cnode->set_inputs(inputs); | |||
| } | |||
| } | |||
| int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::shared_ptr<PrimitiveC> &primitive, | |||
| const std::unique_ptr<schema::CNodeT> &dst_node) { | |||
| @@ -286,23 +251,11 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc | |||
| break; | |||
| } | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| RemoveIfMakeTuple(cnode); | |||
| RemoveIfDepend(cnode); | |||
| #endif | |||
| if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | |||
| #ifdef SUPPORT_TRAIN | |||
| (primitive_c->Type() == schema::PrimitiveType_Depend) || | |||
| (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || | |||
| #endif | |||
| (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | |||
| continue; | |||
| } | |||
| #ifndef SUPPORT_TRAIN | |||
| RemoveIfMakeTuple(cnode); | |||
| #endif | |||
| auto primT = primitive_c->primitiveT(); | |||
| auto node = std::make_unique<schema::CNodeT>(); | |||
| if (node == nullptr) { | |||
| @@ -41,7 +41,6 @@ class AnfExporter { | |||
| int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *fb_node); | |||
| static void RemoveIfMakeTuple(const CNodePtr &cnode); | |||
| static void RemoveIfDepend(const CNodePtr &cnode); | |||
| protected: | |||
| int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode); | |||
| @@ -59,7 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/update_conv2d_param_pass.cc | |||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ../optimizer/graph/unused_transpose_node_remove_pass.cc | |||
| ../optimizer/graph/identity_remove_pass.cc | |||
| ../optimizer/graph/redundant_op_remove_pass.cc | |||
| ../optimizer/graph/infershape_pass.cc | |||
| ../optimizer/graph/slice_prepose_pass.cc | |||
| ../optimizer/graph/mindir_adjust_pass.cc | |||
| @@ -34,7 +34,7 @@ | |||
| #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | |||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||
| #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||
| #include "tools/optimizer/graph/redundant_op_remove_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | |||
| @@ -144,7 +144,7 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt | |||
| int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, | |||
| const converter::Flags *config) { | |||
| auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); | |||
| const_fold_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | |||
| const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>()); | |||
| if (!config->trainModel) { | |||
| auto inne_context_ptr = std::make_shared<lite::InnerContext>(); | |||
| inne_context_ptr->Init(); | |||
| @@ -13,37 +13,41 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||
| #include "tools/optimizer/graph/redundant_op_remove_pass.h" | |||
| #include <memory> | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore::opt { | |||
| int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| namespace { | |||
| constexpr size_t InputDoubleNum = 2; | |||
| constexpr size_t InputTripleNum = 3; | |||
| constexpr auto kNameLoad = "Load"; | |||
| constexpr auto kNameUpdateState = "UpdateState"; | |||
| } // namespace | |||
| int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto type = opt::GetCNodeType(anf_node); | |||
| if (type != schema::PrimitiveType_Identity) { | |||
| MS_LOG(DEBUG) << "anf node is not a identity node."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto identity_cnode = anf_node->cast<CNodePtr>(); | |||
| if (identity_cnode->inputs().size() != lite::kDoubleNum) { | |||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||
| remove_cnode_.insert(anf_node); | |||
| return lite::RET_NO_CHANGE; | |||
| } else { | |||
| bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1)); | |||
| if (!replace_succ) { | |||
| MS_LOG(ERROR) << "replace identity failed."; | |||
| return lite::RET_ERROR; | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (type == schema::PrimitiveType_Identity) { | |||
| if (cnode->size() != InputDoubleNum) { | |||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||
| remove_cnode_.insert(anf_node); | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| } | |||
| bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | |||
| if (!replace_succ) { | |||
| MS_LOG(ERROR) << "replace redundant op failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | |||
| return lite::RET_NO_CHANGE; | |||
| @@ -53,7 +57,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode->inputs().size() != 3) { | |||
| if (cnode->inputs().size() != InputTripleNum) { | |||
| MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -81,7 +85,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const | |||
| return lite::RET_OK; | |||
| } | |||
| bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| @@ -93,10 +97,22 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type == schema::PrimitiveType_Identity) { | |||
| status = ReplaceIdentity(node, manager); | |||
| } else if (type == schema::PrimitiveType_TupleGetItem) { | |||
| status = ReplaceOp(node, manager); | |||
| } | |||
| if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameLoad))) { | |||
| status = ReplaceOp(node, manager); | |||
| } | |||
| if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameUpdateState))) { | |||
| status = ReplaceOp(node, manager); | |||
| } | |||
| if (type == schema::PrimitiveType_Depend || | |||
| type == schema::PrimitiveType_ControlDepend) { // ControlDepend delete next version. | |||
| status = ReplaceOp(node, manager); | |||
| } | |||
| if (type == schema::PrimitiveType_TupleGetItem) { | |||
| status = ReplaceTupleGetItem(node, manager); | |||
| } else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { | |||
| } | |||
| if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ | |||
| #include <string> | |||
| #include <set> | |||
| #include "backend/optimizer/common/pass.h" | |||
| @@ -24,11 +24,11 @@ | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class RemoveIdentityOpPass : public Pass { | |||
| class RemoveRedundantOpPass : public Pass { | |||
| public: | |||
| RemoveIdentityOpPass() : Pass("remove_identity_pass") {} | |||
| ~RemoveIdentityOpPass() override = default; | |||
| int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {} | |||
| ~RemoveRedundantOpPass() override = default; | |||
| int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| @@ -36,4 +36,4 @@ class RemoveIdentityOpPass : public Pass { | |||
| std::set<AnfNodePtr> remove_cnode_; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||
| #endif // MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ | |||