/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include <numeric>
#include <vector>

#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"

#include "ir/anf.h"
#include "ir/visitor.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
#include "pipeline/jit/action.h"
#include "include/common/debug/draw.h"
#include "frontend/operator/ops.h"
#include "include/common/utils/cse.h"
#include "include/common/utils/convert_utils.h"

namespace mindspore {
namespace opt {
class TestOptOpt : public UT::Common {
 public:
  TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {}

  // Rewrites P(P(x)) to P(x).
  class IdempotentEliminater : public AnfVisitor {
   public:
    AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
      x_ = nullptr;
      AnfVisitor::Match(P, {irpass::IsCNode})(node);
      if (x_ == nullptr || node->func_graph() == nullptr) {
        return nullptr;
      }
      return node->func_graph()->NewCNode({NewValueNode(P), x_});
    }

    void Visit(const CNodePtr &cnode) override {
      if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) {
        x_ = cnode->input(1);
      }
    }

   private:
    AnfNodePtr x_{nullptr};
  };

  // Rewrites Q(constant) to P(constant).
  class QctToP : public AnfVisitor {
   public:
    AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
      v_ = nullptr;
      AnfVisitor::Match(Q, {irpass::IsVNode})(node);
      if (v_ == nullptr || node->func_graph() == nullptr) {
        return nullptr;
      }
      return node->func_graph()->NewCNode({NewValueNode(P), v_});
    }

    void Visit(const ValueNodePtr &vnode) override { v_ = vnode; }

   private:
    AnfNodePtr v_{nullptr};
  };

  void SetUp() {
    elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
    elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
    idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
    Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
  }

  bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
    equiv_node.clear();
    equiv_graph.clear();

    FuncGraphPtr gbefore_clone = BasicClone(gbefore);
    OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
    transform(gbefore_clone, optimizer);

    return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
  }

  bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {
    SubstitutionList eq(opts);
    return CheckTransform(before, after, eq);
  }

 public:
  UT::PyFuncGraphFetcher getPyFun;
  FuncGraphPairMapEquiv equiv_graph;
  NodeMapEquiv equiv_node;
  irpass::OptimizeIRPassLib irpass_lib;

  static const PrimitivePtr P;
  static const PrimitivePtr Q;
  static const PrimitivePtr R;

  SubstitutionPtr elim_Z;
  SubstitutionPtr elim_R;
  SubstitutionPtr idempotent_P;
  SubstitutionPtr Qct_to_P;
  SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
};

const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
const PrimitivePtr TestOptOpt::Q = std::make_shared<Primitive>("Q");
const PrimitivePtr TestOptOpt::R = std::make_shared<Primitive>("R");
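// The substitutions above implement these rewrites: idempotent_P turns P(P(x)) into P(x) and Qct_to_P turns
// Q(constant) into P(constant), per the visitors in the fixture; judging by the Python fixtures they pair with,
// elim_Z removes additions of zero and elim_R strips the R primitive. As a minimal sketch (not part of the
// original suite), idempotent_P can also be exercised on a graph built directly in C++, bypassing the Python
// fixtures. The test name is new, and it assumes graph construction via FuncGraph::add_parameter/set_output
// together with the NewCNode/NewValueNode pattern already used in this file.
TEST_F(TestOptOpt, IdempotentSketch) {
  // before: return P(P(x))
  FuncGraphPtr before = std::make_shared<FuncGraph>();
  AnfNodePtr x = before->add_parameter();
  AnfNodePtr inner = before->NewCNode({NewValueNode(P), x});
  before->set_output(before->NewCNode({NewValueNode(P), inner}));

  // after: return P(x)
  FuncGraphPtr after = std::make_shared<FuncGraph>();
  AnfNodePtr y = after->add_parameter();
  after->set_output(after->NewCNode({NewValueNode(P), y}));

  // One application of idempotent_P should make the graphs isomorphic.
  ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({idempotent_P})));
}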
TEST_F(TestOptOpt, TestCheckOptIsClone) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1");

  ASSERT_TRUE(nullptr != before);
  ASSERT_TRUE(CheckOpt(before, before));
  ASSERT_FALSE(CheckOpt(before, before, std::vector<SubstitutionPtr>({elim_Z})));
}

TEST_F(TestOptOpt, Elim) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after");

  ASSERT_TRUE(nullptr != before);
  ASSERT_TRUE(nullptr != after);
  ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_Z})));
}

TEST_F(TestOptOpt, ElimTwo) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_2");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after");

  ASSERT_TRUE(nullptr != before);
  ASSERT_TRUE(nullptr != after);
  ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_Z})));
}

TEST_F(TestOptOpt, ElimR) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after");

  ASSERT_TRUE(nullptr != before);
  ASSERT_TRUE(nullptr != after);
  ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_R})));
}

TEST_F(TestOptOpt, idempotent) {
  FuncGraphPtr before_2 = getPyFun.CallAndParseRet("test_idempotent", "before_2");
  FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_idempotent", "before_1");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_idempotent", "after");

  ASSERT_TRUE(nullptr != before_2);
  ASSERT_TRUE(nullptr != before_1);
  ASSERT_TRUE(nullptr != after);

  ASSERT_TRUE(CheckOpt(before_1, after, std::vector<SubstitutionPtr>({idempotent_P})));
  ASSERT_TRUE(CheckOpt(before_2, after, std::vector<SubstitutionPtr>({idempotent_P})));
}

TEST_F(TestOptOpt, ConstantVariable) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_constant_variable", "before_1");
  FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_variable", "after");

  ASSERT_TRUE(nullptr != before);
  ASSERT_TRUE(nullptr != after);
  ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({Qct_to_P})));
}

TEST_F(TestOptOpt, CSE) {
  // Test a simple CSE case, test_f1.
  FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1");
  ASSERT_TRUE(nullptr != test_graph1);

  // Add the func graph to a FuncGraphManager.
  FuncGraphManagerPtr manager1 = Manage(test_graph1);
  ASSERT_EQ(manager1->all_nodes().size(), 9);
  auto cse = std::make_shared<CSE>();
  ASSERT_TRUE(cse != nullptr);
  bool is_changed = cse->Cse(test_graph1, manager1);
  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager1->all_nodes().size(), 8);

  // Test a more complicated case, test_f2.
  FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2");
  ASSERT_TRUE(nullptr != test_graph2);
  FuncGraphManagerPtr manager2 = Manage(test_graph2);
  ASSERT_EQ(manager2->all_nodes().size(), 16);
  is_changed = cse->Cse(test_graph2, manager2);
  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager2->all_nodes().size(), 12);
}
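// CSE merges structurally identical subexpressions into a single node, which is why the managed node counts above
// shrink (9 -> 8 for test_f1, 16 -> 12 for test_f2) while the graphs stay semantically equivalent.

// Helper for the tuple-flatten tests below: walks the whole graph (following called graphs and partials) and counts
// every argument and parameter whose abstract is a tuple. The flatten tests assert that this count drops to zero
// after the transform.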
size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
  // Check tuple params and tuple args.
  auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
  size_t tuple_arg_param_num = 0;
  auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t {
    auto abs = node->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num;
  };
  for (const auto &node : all_nodes) {
    // Count func graph call tuple args.
    if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
      auto call_node = node->cast<CNodePtr>();
      tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(),
                                            tuple_arg_param_num, tuple_accumulate_func);
    }

    // Count partial tuple args.
    if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
      auto partial = node->cast<CNodePtr>();
      constexpr auto kPartialFirstArgIdx = 2;
      tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(),
                                            tuple_arg_param_num, tuple_accumulate_func);
    }

    // Count tuple params.
    if (IsValueNode<FuncGraph>(node)) {
      auto fg = GetValueNode<FuncGraphPtr>(node);
      tuple_arg_param_num = std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num,
                                            tuple_accumulate_func);
    }
  }
  return tuple_arg_param_num;
}

// Feature: Switch call tuple arg transform.
// Description: Test the switch call's tuple arg transform. This case includes a partial's tuple arg and the call's
// tuple arg at the same time.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SwitchPartialTupleTrans) {
  FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg");
  ASSERT_TRUE(nullptr != test_graph);
  FuncGraphManagerPtr manager1 = Manage(test_graph);
  pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
  std::vector<AbstractBasePtr> args_spec;

  // Renormalize first.
  auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
  ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);

  // Flatten tuple params and args.
  OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
  SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
  transform(renormalized_fg, optimizer);

  // Renormalize again.
  auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
  ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  abstract::AnalysisContext::ClearContext();
}

// Feature: Switch layer call tuple arg transform.
// Description: Test the switch layer call's tuple arg transform. This case includes a partial's tuple arg and the
// partial's tensor arg at the same time.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) {
  FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg");
  ASSERT_TRUE(nullptr != test_graph);
  FuncGraphManagerPtr manager1 = Manage(test_graph);
  pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
  std::vector<AbstractBasePtr> args_spec;

  // Renormalize first.
  auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
  ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);

  // Flatten tuple params and args.
  OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
  SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
  transform(renormalized_fg, optimizer);

  // Renormalize again.
  auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
  ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  abstract::AnalysisContext::ClearContext();
}
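// The last case below follows the same renormalize -> flatten -> renormalize pattern as the two switch tests above;
// the extra wrinkle is that one tuple arg itself contains a tuple, so the pass has to flatten nested tuples as well.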
// Feature: Single graph call tuple arg transform.
// Description: Test a single graph call's tuple arg transform. This case includes a tuple nested in a tuple arg.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) {
  FuncGraphPtr test_graph =
    getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg");
  ASSERT_TRUE(nullptr != test_graph);
  FuncGraphManagerPtr manager1 = Manage(test_graph);
  pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
  std::vector<AbstractBasePtr> args_spec;

  // Renormalize first.
  auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
  ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);

  // Flatten tuple params and args.
  OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
  SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
  transform(renormalized_fg, optimizer);

  // Renormalize again.
  auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
  ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  abstract::AnalysisContext::ClearContext();
}
}  // namespace opt
}  // namespace mindspore