From: @chenfei52 Reviewed-by: @zh_qh Signed-off-by: @zh_qhtags/v1.2.0-rc1
| @@ -123,8 +123,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | ||||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | IsValueNode<RefKey>, opt::FORCE_RENORM); | ||||
| replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam); | replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam); | ||||
| // Gradient transforms | |||||
| expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ); | |||||
| minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem); | minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem); | ||||
| // branch culling | // branch culling | ||||
| @@ -85,7 +85,6 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr accumulaten_eliminater_; | SubstitutionPtr accumulaten_eliminater_; | ||||
| // Gradient irpasses | // Gradient irpasses | ||||
| SubstitutionPtr expand_jprim_; | |||||
| SubstitutionPtr minmaximum_grad_; | SubstitutionPtr minmaximum_grad_; | ||||
| // inline | // inline | ||||
| @@ -22,25 +22,28 @@ namespace irpass { | |||||
| namespace internal { | namespace internal { | ||||
| AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | ||||
| ScopeGuard scope_guard(vnode->scope()); | ScopeGuard scope_guard(vnode->scope()); | ||||
| auto newg = ad::Kprim(vnode, resource); | auto newg = ad::Kprim(vnode, resource); | ||||
| if (newg != nullptr) { | if (newg != nullptr) { | ||||
| return NewValueNode(newg); | return NewValueNode(newg); | ||||
| } | } | ||||
| // when find in J failed, try in Jmeta | // when find in J failed, try in Jmeta | ||||
| auto prim = GetValueNode<PrimitivePtr>(vnode); | auto prim = GetValueNode<PrimitivePtr>(vnode); | ||||
| MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); | MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); | ||||
| if (meta != nullptr) { | if (meta != nullptr) { | ||||
| return NewValueNode(meta); | return NewValueNode(meta); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) { | |||||
| // if func graph also contain J(FuncGraph) or J(Primitive), then ignore this funcgraph. | |||||
| // ExpandJ innermost graph first. | |||||
| bool CheckIfEmbedJ(const CNodePtr &j_node) { | |||||
| auto &value_node = j_node->input(1); | |||||
| if (IsValueNode<Primitive>(value_node)) { | |||||
| return false; | |||||
| } | |||||
| auto func_graph = GetValueNode<FuncGraphPtr>(value_node); | |||||
| if (func_graph == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Unexpected j node:" << j_node->DebugString(); | |||||
| } | |||||
| auto func_graph_manager = func_graph->manager(); | auto func_graph_manager = func_graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(func_graph_manager); | MS_EXCEPTION_IF_NULL(func_graph_manager); | ||||
| return func_graph_manager->func_graph_j_total(func_graph); | return func_graph_manager->func_graph_j_total(func_graph); | ||||
| @@ -49,31 +52,48 @@ bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) { | |||||
| AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | ||||
| if (IsValueNode<FuncGraph>(vnode)) { | if (IsValueNode<FuncGraph>(vnode)) { | ||||
| ScopeGuard scope_guard(vnode->scope()); | ScopeGuard scope_guard(vnode->scope()); | ||||
| auto func_graph = GetValueNode<FuncGraphPtr>(vnode); | auto func_graph = GetValueNode<FuncGraphPtr>(vnode); | ||||
| MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); | |||||
| // high_order_grad begin; | |||||
| // if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph. | |||||
| // ExpandJ innermost graph or primitive first. | |||||
| if (CheckIfEmbedJ(func_graph)) { | |||||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later"; | |||||
| return nullptr; | |||||
| } | |||||
| // high_order_grad end; | |||||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; | MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; | ||||
| auto newfg = ad::Grad(func_graph, resource); | auto newfg = ad::Grad(func_graph, resource); | ||||
| return NewValueNode(newfg); | return NewValueNode(newfg); | ||||
| } | } | ||||
| if (IsValueNode<Primitive>(vnode)) { | if (IsValueNode<Primitive>(vnode)) { | ||||
| return ExpandJPrimitive(vnode, resource); | return ExpandJPrimitive(vnode, resource); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } // namespace internal | } // namespace internal | ||||
| bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { | |||||
| // Search all j nodes. | |||||
| GetJPrim(optimizer->resource()->manager()); | |||||
| // Get j nodes that don't have embed j nodes. | |||||
| std::vector<CNodePtr> todo; | |||||
| // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph. | |||||
| // ExpandJ innermost graph or primitive first. | |||||
| std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo), | |||||
| [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); }); | |||||
| // Expand j nodes that don't have embed j nodes. | |||||
| bool change = false; | |||||
| for (auto &j_node : todo) { | |||||
| auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource()); | |||||
| optimizer->resource()->manager()->Replace(j_node, expanded_j); | |||||
| change = true; | |||||
| } | |||||
| return change; | |||||
| } | |||||
| void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) { | |||||
| j_nodes_.clear(); | |||||
| for (auto &fg : manager->func_graphs()) { | |||||
| std::vector<AnfNodePtr> &&toposet = TopoSort(fg->get_return()); | |||||
| for (const auto &node : toposet) { | |||||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||||
| j_nodes_.push_back(node->cast<CNodePtr>()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,28 +31,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| namespace internal { | |||||
| AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource); | |||||
| } // namespace internal | |||||
| // {prim::kPrimJ, C} | // {prim::kPrimJ, C} | ||||
| class ExpandJPrim : public AnfVisitor { | |||||
| class ExpandJPrim { | |||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| x_ = nullptr; | |||||
| AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node); | |||||
| if (x_ != nullptr) { | |||||
| TraceGuard guard(std::make_shared<TraceExpandJ>(node->debug_info())); | |||||
| auto j_node = internal::ExpandJ(x_, optimizer->resource()); | |||||
| return j_node; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void Visit(const ValueNodePtr &node) override { x_ = node; } | |||||
| ExpandJPrim() = default; | |||||
| virtual ~ExpandJPrim() = default; | |||||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer); | |||||
| void GetJPrim(const FuncGraphManagerPtr &manager); | |||||
| private: | private: | ||||
| ValueNodePtr x_{nullptr}; | |||||
| std::vector<CNodePtr> j_nodes_; | |||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -41,6 +41,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "pipeline/jit/pipeline_split.h" | #include "pipeline/jit/pipeline_split.h" | ||||
| #include "pipeline/jit/static_analysis/auto_monad.h" | #include "pipeline/jit/static_analysis/auto_monad.h" | ||||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| @@ -166,7 +167,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.mini_step_allgather_replace_, | irpass.mini_step_allgather_replace_, | ||||
| }); | }); | ||||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | ||||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | |||||
| opt::irpass::ResolveIRPassLib resolve_irpass; | opt::irpass::ResolveIRPassLib resolve_irpass; | ||||
| opt::OptPassConfig resolve_pass = | opt::OptPassConfig resolve_pass = | ||||
| @@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| {"parallel", opt::OptPassConfig(parallel::StepParallel)}, | {"parallel", opt::OptPassConfig(parallel::StepParallel)}, | ||||
| {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | ||||
| {"virtual_dataset", virtual_dataset}, | {"virtual_dataset", virtual_dataset}, | ||||
| {"grad", grad}, | |||||
| {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, | |||||
| {"resolve", resolve_pass}, | {"resolve", resolve_pass}, | ||||
| {"a_after_grad", a_after_grad}, | {"a_after_grad", a_after_grad}, | ||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | {"renormalize", opt::OptPassConfig::Renormalize()}, | ||||
| @@ -88,14 +88,6 @@ class TestOptLib : public UT::Common { | |||||
| irpass::OptimizeIRPassLib irpass; | irpass::OptimizeIRPassLib irpass; | ||||
| }; | }; | ||||
| TEST_F(TestOptLib, test_expendJ) { | |||||
| FuncGraphPtr before = getPyFun("test_expendJ"); | |||||
| ASSERT_TRUE(nullptr != before); | |||||
| FuncGraphPtr after = RunSubs(before, std::vector<SubstitutionPtr>({irpass.expand_jprim_})); | |||||
| } | |||||
| TEST_F(TestOptLib, test_simplify_always_true_false) { | TEST_F(TestOptLib, test_simplify_always_true_false) { | ||||
| FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1"); | FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1"); | ||||
| FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2"); | FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2"); | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "frontend/optimizer/cse_pass.h" | #include "frontend/optimizer/cse_pass.h" | ||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/optimizer/irpass.h" | #include "frontend/optimizer/irpass.h" | ||||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -38,23 +39,24 @@ class TestOptOptimizer : public UT::Common { | |||||
| }; | }; | ||||
| TEST_F(TestOptOptimizer, test_step_opt) { | TEST_F(TestOptOptimizer, test_step_opt) { | ||||
| FuncGraphPtr before = getPyFun("test_expendJ"); | |||||
| FuncGraphPtr before = getPyFun("test_expandJ"); | |||||
| ASSERT_TRUE(nullptr != before); | ASSERT_TRUE(nullptr != before); | ||||
| pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>(); | pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>(); | ||||
| std::shared_ptr<Optimizer> optimizer = Optimizer::MakeOptimizer("ut_test", res, | |||||
| {{"main", | |||||
| { | |||||
| // Branch culling | |||||
| irpass.switch_simplify_, | |||||
| std::shared_ptr<Optimizer> optimizer = | |||||
| Optimizer::MakeOptimizer("ut_test", res, | |||||
| {{"main", | |||||
| { | |||||
| // Branch culling | |||||
| irpass.switch_simplify_, | |||||
| // Safe inlining | |||||
| irpass.arithmetic_simplify_, | |||||
| irpass.inline_, | |||||
| }}, | |||||
| {"grad", {irpass.expand_jprim_}}, | |||||
| {"cse", OptPassConfig(CSEPass(false))}}, | |||||
| true); | |||||
| // Safe inlining | |||||
| irpass.arithmetic_simplify_, | |||||
| irpass.inline_, | |||||
| }}, | |||||
| {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, | |||||
| {"cse", OptPassConfig(CSEPass(false))}}, | |||||
| true); | |||||
| EXPECT_TRUE(optimizer.get() != nullptr); | EXPECT_TRUE(optimizer.get() != nullptr); | ||||
| auto after = optimizer->step(before); | auto after = optimizer->step(before); | ||||
| @@ -133,8 +133,8 @@ def cost(x): | |||||
| J = Primitive('J') | J = Primitive('J') | ||||
| def test_expendJ(x): | |||||
| """ test_expendJ """ | |||||
| def test_expandJ(x): | |||||
| """ test_expandJ """ | |||||
| return J(cost)(x) | return J(cost)(x) | ||||