| @@ -225,7 +225,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc | ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | |||||
| ${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc | |||||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | ||||
| @@ -59,41 +59,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | |||||
| } | } | ||||
| } | } | ||||
| void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||||
| bool hasDepend = false; | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| inputs.emplace_back(cnode->input(0)); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| AnfNodePtr inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| continue; | |||||
| } | |||||
| auto dependNode = utils::cast<CNodePtr>(inputNode); | |||||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | |||||
| IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | |||||
| hasDepend = true; | |||||
| bool maskOut = (dependNode->inputs().size() == 3); | |||||
| for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | |||||
| AnfNodePtr dependInputNode = dependNode->input(j); | |||||
| if (dependInputNode->isa<CNode>()) { | |||||
| inputs.emplace_back(dependInputNode); | |||||
| if (maskOut) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| } | |||||
| } | |||||
| if (hasDepend) { | |||||
| cnode->set_inputs(inputs); | |||||
| } | |||||
| } | |||||
| int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | ||||
| const std::shared_ptr<PrimitiveC> &primitive, | const std::shared_ptr<PrimitiveC> &primitive, | ||||
| const std::unique_ptr<schema::CNodeT> &dst_node) { | const std::unique_ptr<schema::CNodeT> &dst_node) { | ||||
| @@ -286,23 +251,11 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| #ifdef SUPPORT_TRAIN | |||||
| RemoveIfMakeTuple(cnode); | RemoveIfMakeTuple(cnode); | ||||
| RemoveIfDepend(cnode); | |||||
| #endif | |||||
| if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | ||||
| #ifdef SUPPORT_TRAIN | |||||
| (primitive_c->Type() == schema::PrimitiveType_Depend) || | |||||
| (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || | |||||
| #endif | |||||
| (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| #ifndef SUPPORT_TRAIN | |||||
| RemoveIfMakeTuple(cnode); | |||||
| #endif | |||||
| auto primT = primitive_c->primitiveT(); | auto primT = primitive_c->primitiveT(); | ||||
| auto node = std::make_unique<schema::CNodeT>(); | auto node = std::make_unique<schema::CNodeT>(); | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| @@ -41,7 +41,6 @@ class AnfExporter { | |||||
| int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | ||||
| schema::CNodeT *fb_node); | schema::CNodeT *fb_node); | ||||
| static void RemoveIfMakeTuple(const CNodePtr &cnode); | static void RemoveIfMakeTuple(const CNodePtr &cnode); | ||||
| static void RemoveIfDepend(const CNodePtr &cnode); | |||||
| protected: | protected: | ||||
| int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode); | int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode); | ||||
| @@ -59,7 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/graph/update_conv2d_param_pass.cc | ../optimizer/graph/update_conv2d_param_pass.cc | ||||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | ../optimizer/graph/unused_cast_node_remove_pass.cc | ||||
| ../optimizer/graph/unused_transpose_node_remove_pass.cc | ../optimizer/graph/unused_transpose_node_remove_pass.cc | ||||
| ../optimizer/graph/identity_remove_pass.cc | |||||
| ../optimizer/graph/redundant_op_remove_pass.cc | |||||
| ../optimizer/graph/infershape_pass.cc | ../optimizer/graph/infershape_pass.cc | ||||
| ../optimizer/graph/slice_prepose_pass.cc | ../optimizer/graph/slice_prepose_pass.cc | ||||
| ../optimizer/graph/mindir_adjust_pass.cc | ../optimizer/graph/mindir_adjust_pass.cc | ||||
| @@ -34,7 +34,7 @@ | |||||
| #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | ||||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | #include "tools/optimizer/graph/mindir_adjust_pass.h" | ||||
| #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | ||||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||||
| #include "tools/optimizer/graph/redundant_op_remove_pass.h" | |||||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | ||||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | #include "tools/optimizer/graph/weight_format_transform_pass.h" | ||||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | #include "tools/optimizer/graph/clip_convert_activation_pass.h" | ||||
| @@ -144,7 +144,7 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt | |||||
| int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, | int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, | ||||
| const converter::Flags *config) { | const converter::Flags *config) { | ||||
| auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); | auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); | ||||
| const_fold_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | |||||
| const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>()); | |||||
| if (!config->trainModel) { | if (!config->trainModel) { | ||||
| auto inne_context_ptr = std::make_shared<lite::InnerContext>(); | auto inne_context_ptr = std::make_shared<lite::InnerContext>(); | ||||
| inne_context_ptr->Init(); | inne_context_ptr->Init(); | ||||
| @@ -13,37 +13,41 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||||
| #include "tools/optimizer/graph/redundant_op_remove_pass.h" | |||||
| #include <memory> | |||||
| #include "mindspore/lite/include/errorcode.h" | #include "mindspore/lite/include/errorcode.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||||
| namespace { | |||||
| constexpr size_t InputDoubleNum = 2; | |||||
| constexpr size_t InputTripleNum = 3; | |||||
| constexpr auto kNameLoad = "Load"; | |||||
| constexpr auto kNameUpdateState = "UpdateState"; | |||||
| } // namespace | |||||
| int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||||
| if (!utils::isa<CNodePtr>(anf_node)) { | if (!utils::isa<CNodePtr>(anf_node)) { | ||||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | MS_LOG(DEBUG) << "anf node is node a cnode."; | ||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| auto type = opt::GetCNodeType(anf_node); | auto type = opt::GetCNodeType(anf_node); | ||||
| if (type != schema::PrimitiveType_Identity) { | |||||
| MS_LOG(DEBUG) << "anf node is not a identity node."; | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| auto identity_cnode = anf_node->cast<CNodePtr>(); | |||||
| if (identity_cnode->inputs().size() != lite::kDoubleNum) { | |||||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||||
| remove_cnode_.insert(anf_node); | |||||
| return lite::RET_NO_CHANGE; | |||||
| } else { | |||||
| bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1)); | |||||
| if (!replace_succ) { | |||||
| MS_LOG(ERROR) << "replace identity failed."; | |||||
| return lite::RET_ERROR; | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| if (type == schema::PrimitiveType_Identity) { | |||||
| if (cnode->size() != InputDoubleNum) { | |||||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||||
| remove_cnode_.insert(anf_node); | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | } | ||||
| } | } | ||||
| bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | |||||
| if (!replace_succ) { | |||||
| MS_LOG(ERROR) << "replace redundant op failed."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||||
| int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||||
| if (!utils::isa<CNodePtr>(anf_node)) { | if (!utils::isa<CNodePtr>(anf_node)) { | ||||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | MS_LOG(DEBUG) << "anf node is node a cnode."; | ||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| @@ -53,7 +57,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const | |||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| auto cnode = anf_node->cast<CNodePtr>(); | auto cnode = anf_node->cast<CNodePtr>(); | ||||
| if (cnode->inputs().size() != 3) { | |||||
| if (cnode->inputs().size() != InputTripleNum) { | |||||
| MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size(); | MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -81,7 +85,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const | |||||
| return lite::RET_OK; | return lite::RET_OK; | ||||
| } | } | ||||
| bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||||
| bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_ASSERT(func_graph != nullptr); | MS_ASSERT(func_graph != nullptr); | ||||
| auto manager = func_graph->manager(); | auto manager = func_graph->manager(); | ||||
| MS_ASSERT(manager != nullptr); | MS_ASSERT(manager != nullptr); | ||||
| @@ -93,10 +97,22 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| auto type = opt::GetCNodeType(node); | auto type = opt::GetCNodeType(node); | ||||
| if (type == schema::PrimitiveType_Identity) { | if (type == schema::PrimitiveType_Identity) { | ||||
| status = ReplaceIdentity(node, manager); | |||||
| } else if (type == schema::PrimitiveType_TupleGetItem) { | |||||
| status = ReplaceOp(node, manager); | |||||
| } | |||||
| if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameLoad))) { | |||||
| status = ReplaceOp(node, manager); | |||||
| } | |||||
| if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameUpdateState))) { | |||||
| status = ReplaceOp(node, manager); | |||||
| } | |||||
| if (type == schema::PrimitiveType_Depend || | |||||
| type == schema::PrimitiveType_ControlDepend) { // ControlDepend delete next version. | |||||
| status = ReplaceOp(node, manager); | |||||
| } | |||||
| if (type == schema::PrimitiveType_TupleGetItem) { | |||||
| status = ReplaceTupleGetItem(node, manager); | status = ReplaceTupleGetItem(node, manager); | ||||
| } else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { | |||||
| } | |||||
| if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { | |||||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1)); | auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1)); | ||||
| if (sub_func_graph == nullptr) { | if (sub_func_graph == nullptr) { | ||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <set> | #include <set> | ||||
| #include "backend/optimizer/common/pass.h" | #include "backend/optimizer/common/pass.h" | ||||
| @@ -24,11 +24,11 @@ | |||||
| using mindspore::lite::converter::FmkType; | using mindspore::lite::converter::FmkType; | ||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| class RemoveIdentityOpPass : public Pass { | |||||
| class RemoveRedundantOpPass : public Pass { | |||||
| public: | public: | ||||
| RemoveIdentityOpPass() : Pass("remove_identity_pass") {} | |||||
| ~RemoveIdentityOpPass() override = default; | |||||
| int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||||
| RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {} | |||||
| ~RemoveRedundantOpPass() override = default; | |||||
| int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||||
| int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | ||||
| bool Run(const FuncGraphPtr &graph) override; | bool Run(const FuncGraphPtr &graph) override; | ||||
| @@ -36,4 +36,4 @@ class RemoveIdentityOpPass : public Pass { | |||||
| std::set<AnfNodePtr> remove_cnode_; | std::set<AnfNodePtr> remove_cnode_; | ||||
| }; | }; | ||||
| } // namespace mindspore::opt | } // namespace mindspore::opt | ||||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ | |||||