diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index ac1e262a97..c866d83f25 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -123,8 +123,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
                                               IsValueNode<RefKey>, opt::FORCE_RENORM);
   replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
-  // Gradient transforms
-  expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
   minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
 
   // branch culling
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index b3604a2cdb..9ad4ab9ec5 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -85,7 +85,6 @@ class OptimizeIRPassLib {
   SubstitutionPtr accumulaten_eliminater_;
 
   // Gradient irpasses
-  SubstitutionPtr expand_jprim_;
   SubstitutionPtr minmaximum_grad_;
 
   // inline
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
index a7a732b917..3e6bb2c543 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
@@ -22,25 +22,28 @@ namespace irpass {
 namespace internal {
 AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
   ScopeGuard scope_guard(vnode->scope());
-
   auto newg = ad::Kprim(vnode, resource);
   if (newg != nullptr) {
     return NewValueNode(newg);
   }
-
   // when find in J failed, try in Jmeta
   auto prim = GetValueNode<PrimitivePtr>(vnode);
   MetaFuncGraphPtr meta = ad::Kmeta(prim, resource);
   if (meta != nullptr) {
     return NewValueNode(meta);
   }
-
   return nullptr;
 }
 
-bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
-  // if func graph also contain J(FuncGraph) or J(Primitive), then ignore this funcgraph.
-  // ExpandJ innermost graph first.
+bool CheckIfEmbedJ(const CNodePtr &j_node) {
+  auto &value_node = j_node->input(1);
+  if (IsValueNode<Primitive>(value_node)) {
+    return false;
+  }
+  auto func_graph = GetValueNode<FuncGraphPtr>(value_node);
+  if (func_graph == nullptr) {
+    MS_LOG(EXCEPTION) << "Unexpected j node:" << j_node->DebugString();
+  }
   auto func_graph_manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(func_graph_manager);
   return func_graph_manager->func_graph_j_total(func_graph);
@@ -49,31 +52,48 @@ bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
 AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
   if (IsValueNode<FuncGraph>(vnode)) {
     ScopeGuard scope_guard(vnode->scope());
-
     auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
-    MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString();
-
-    // high_order_grad begin;
-    // if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
-    // ExpandJ innermost graph or primitive first.
-    if (CheckIfEmbedJ(func_graph)) {
-      MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later";
-      return nullptr;
-    }
-    // high_order_grad end;
-
     MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now";
     auto newfg = ad::Grad(func_graph, resource);
     return NewValueNode(newfg);
   }
-
   if (IsValueNode<Primitive>(vnode)) {
     return ExpandJPrimitive(vnode, resource);
   }
-
   return nullptr;
 }
 }  // namespace internal
+
+bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
+  // Search all j nodes.
+  GetJPrim(optimizer->resource()->manager());
+  // Get j nodes that don't have embed j nodes.
+  std::vector<CNodePtr> todo;
+  // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
+  // ExpandJ innermost graph or primitive first.
+  std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
+               [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
+  // Expand j nodes that don't have embed j nodes.
+  bool change = false;
+  for (auto &j_node : todo) {
+    auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
+    optimizer->resource()->manager()->Replace(j_node, expanded_j);
+    change = true;
+  }
+  return change;
+}
+
+void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) {
+  j_nodes_.clear();
+  for (auto &fg : manager->func_graphs()) {
+    std::vector<AnfNodePtr> &&toposet = TopoSort(fg->get_return());
+    for (const auto &node : toposet) {
+      if (IsPrimitiveCNode(node, prim::kPrimJ)) {
+        j_nodes_.push_back(node->cast<CNodePtr>());
+      }
+    }
+  }
+}
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
index dde9b66f13..e3dedf6d1f 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
@@ -31,28 +31,17 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
-namespace internal {
-AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource);
-}  // namespace internal
 
 // {prim::kPrimJ, C}
-class ExpandJPrim : public AnfVisitor {
+class ExpandJPrim {
  public:
-  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
-    x_ = nullptr;
-    AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node);
-    if (x_ != nullptr) {
-      TraceGuard guard(std::make_shared<TraceExpandJ>(node->debug_info()));
-      auto j_node = internal::ExpandJ(x_, optimizer->resource());
-      return j_node;
-    }
-    return nullptr;
-  }
-
-  void Visit(const ValueNodePtr &node) override { x_ = node; }
+  ExpandJPrim() = default;
+  virtual ~ExpandJPrim() = default;
+  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
+  void GetJPrim(const FuncGraphManagerPtr &manager);
 
  private:
-  ValueNodePtr x_{nullptr};
+  std::vector<CNodePtr> j_nodes_;
 };
 }  // namespace irpass
 }  // namespace opt
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 7b26234f85..9bf1b4cd28 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -41,6 +41,7 @@
 #include "utils/log_adapter.h"
 #include "pipeline/jit/pipeline_split.h"
 #include "pipeline/jit/static_analysis/auto_monad.h"
+#include "frontend/optimizer/irpass/gradient_eliminate.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/util.h"
 #include "ps/ps_context.h"
@@ -166,7 +167,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.mini_step_allgather_replace_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
-  opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
   opt::irpass::ResolveIRPassLib resolve_irpass;
 
   opt::OptPassConfig resolve_pass =
@@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     {"parallel", opt::OptPassConfig(parallel::StepParallel)},
     {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
     {"virtual_dataset", virtual_dataset},
-    {"grad", grad},
+    {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
     {"resolve", resolve_pass},
     {"a_after_grad", a_after_grad},
     {"renormalize", opt::OptPassConfig::Renormalize()},
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index 38881362d5..4f0d3fc6bb 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -88,14 +88,6 @@ class TestOptLib : public UT::Common {
   irpass::OptimizeIRPassLib irpass;
 };
 
-TEST_F(TestOptLib, test_expendJ) {
-  FuncGraphPtr before = getPyFun("test_expendJ");
-
-  ASSERT_TRUE(nullptr != before);
-
-  FuncGraphPtr after = RunSubs(before, std::vector<SubstitutionPtr>({irpass.expand_jprim_}));
-}
-
 TEST_F(TestOptLib, test_simplify_always_true_false) {
   FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1");
   FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2");
diff --git a/tests/ut/cpp/optimizer/optimizer_test.cc b/tests/ut/cpp/optimizer/optimizer_test.cc
index 6e9d04f6cf..a5e119cc27 100644
--- a/tests/ut/cpp/optimizer/optimizer_test.cc
+++ b/tests/ut/cpp/optimizer/optimizer_test.cc
@@ -24,6 +24,7 @@
 #include "frontend/optimizer/cse_pass.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/irpass/gradient_eliminate.h"
 #include "debug/draw.h"
 
 namespace mindspore {
@@ -38,23 +39,24 @@
 };
 
 TEST_F(TestOptOptimizer, test_step_opt) {
-  FuncGraphPtr before = getPyFun("test_expendJ");
+  FuncGraphPtr before = getPyFun("test_expandJ");
 
   ASSERT_TRUE(nullptr != before);
   pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
-  std::shared_ptr<Optimizer> optimizer = Optimizer::MakeOptimizer("ut_test", res,
-                                                                  {{"main",
-                                                                    {
-                                                                      // Branch culling
-                                                                      irpass.switch_simplify_,
+  std::shared_ptr<Optimizer> optimizer =
+    Optimizer::MakeOptimizer("ut_test", res,
+                             {{"main",
+                               {
+                                 // Branch culling
+                                 irpass.switch_simplify_,
 
-                                                                      // Safe inlining
-                                                                      irpass.arithmetic_simplify_,
-                                                                      irpass.inline_,
-                                                                    }},
-                                                                   {"grad", {irpass.expand_jprim_}},
-                                                                   {"cse", OptPassConfig(CSEPass(false))}},
-                                                                  true);
+                                 // Safe inlining
+                                 irpass.arithmetic_simplify_,
+                                 irpass.inline_,
+                               }},
+                              {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
+                              {"cse", OptPassConfig(CSEPass(false))}},
+                             true);
   EXPECT_TRUE(optimizer.get() != nullptr);
 
   auto after = optimizer->step(before);
diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
index c2c773c2e5..e91f01da33 100644
--- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -133,8 +133,8 @@ def cost(x):
 J = Primitive('J')
 
 
-def test_expendJ(x):
-    """ test_expendJ """
+def test_expandJ(x):
+    """ test_expandJ """
     return J(cost)(x)
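
Note (not part of the patch): the reworked pass expands J applications innermost-first. GetJPrim collects every kPrimJ CNode across all graphs, CheckIfEmbedJ defers any J whose callee graph still contains a nested J (via func_graph_j_total), and the optimizer reruns the pass until operator() reports no change. The program below is a minimal, self-contained sketch of that fixed-point strategy on a toy expression tree; Node, CollectJ, HasEmbeddedJ, and ExpandJStep are hypothetical stand-ins for illustration only, not MindSpore APIs.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string op;                             // "J", "prim", or "grad" after expansion
  std::vector<std::shared_ptr<Node>> inputs;  // operands (for "J": the callee)
};
using NodePtr = std::shared_ptr<Node>;

// Analogue of ExpandJPrim::GetJPrim: collect every J node reachable from `node`.
void CollectJ(const NodePtr &node, std::vector<NodePtr> *out) {
  if (node == nullptr) return;
  if (node->op == "J") out->push_back(node);
  for (const auto &input : node->inputs) CollectJ(input, out);
}

// Analogue of internal::CheckIfEmbedJ: a J whose operand still contains
// another J must wait until the inner one has been expanded.
bool HasEmbeddedJ(const NodePtr &j_node) {
  std::vector<NodePtr> inner;
  for (const auto &input : j_node->inputs) CollectJ(input, &inner);
  return !inner.empty();
}

// Analogue of ExpandJPrim::operator(): expand every J with no embedded J
// and report whether anything changed; rerunning to a fixed point walks
// higher-order grads from the innermost J outwards.
bool ExpandJStep(const NodePtr &root) {
  std::vector<NodePtr> j_nodes;
  CollectJ(root, &j_nodes);
  bool changed = false;
  for (const auto &j : j_nodes) {
    if (HasEmbeddedJ(j)) continue;  // skip: innermost first
    j->op = "grad";                 // stand-in for ad::Grad / ad::Kprim
    changed = true;
  }
  return changed;
}

int main() {
  // J(J(f)): the outer J must wait one iteration for the inner one.
  auto f = std::make_shared<Node>(Node{"prim", {}});
  auto inner = std::make_shared<Node>(Node{"J", {f}});
  auto outer = std::make_shared<Node>(Node{"J", {inner}});
  int steps = 0;
  while (ExpandJStep(outer)) ++steps;
  std::cout << "expanded in " << steps << " steps\n";  // prints: expanded in 2 steps
  return 0;
}

Running the sketch prints "expanded in 2 steps": the inner J of J(J(f)) is expanded on the first pass and the outer one on the second, which mirrors why the old per-node substitution (and its one-shot "grad" OptPassConfig) is replaced by a whole-graph pass that the pass group can iterate to convergence.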