From: @chenfei52 Reviewed-by: @zh_qh Signed-off-by: @zh_qhtags/v1.2.0-rc1
| @@ -123,8 +123,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | |||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam); | |||
| // Gradient transforms | |||
| expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ); | |||
| minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| // branch culling | |||
| @@ -85,7 +85,6 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr accumulaten_eliminater_; | |||
| // Gradient irpasses | |||
| SubstitutionPtr expand_jprim_; | |||
| SubstitutionPtr minmaximum_grad_; | |||
| // inline | |||
| @@ -22,25 +22,28 @@ namespace irpass { | |||
| namespace internal { | |||
| AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | |||
| ScopeGuard scope_guard(vnode->scope()); | |||
| auto newg = ad::Kprim(vnode, resource); | |||
| if (newg != nullptr) { | |||
| return NewValueNode(newg); | |||
| } | |||
| // when find in J failed, try in Jmeta | |||
| auto prim = GetValueNode<PrimitivePtr>(vnode); | |||
| MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); | |||
| if (meta != nullptr) { | |||
| return NewValueNode(meta); | |||
| } | |||
| return nullptr; | |||
| } | |||
| bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) { | |||
| // if func graph also contain J(FuncGraph) or J(Primitive), then ignore this funcgraph. | |||
| // ExpandJ innermost graph first. | |||
| bool CheckIfEmbedJ(const CNodePtr &j_node) { | |||
| auto &value_node = j_node->input(1); | |||
| if (IsValueNode<Primitive>(value_node)) { | |||
| return false; | |||
| } | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(value_node); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Unexpected j node:" << j_node->DebugString(); | |||
| } | |||
| auto func_graph_manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(func_graph_manager); | |||
| return func_graph_manager->func_graph_j_total(func_graph); | |||
| @@ -49,31 +52,48 @@ bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) { | |||
| AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | |||
| if (IsValueNode<FuncGraph>(vnode)) { | |||
| ScopeGuard scope_guard(vnode->scope()); | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(vnode); | |||
| MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); | |||
| // high_order_grad begin; | |||
| // if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph. | |||
| // ExpandJ innermost graph or primitive first. | |||
| if (CheckIfEmbedJ(func_graph)) { | |||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later"; | |||
| return nullptr; | |||
| } | |||
| // high_order_grad end; | |||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; | |||
| auto newfg = ad::Grad(func_graph, resource); | |||
| return NewValueNode(newfg); | |||
| } | |||
| if (IsValueNode<Primitive>(vnode)) { | |||
| return ExpandJPrimitive(vnode, resource); | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace internal | |||
| bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { | |||
| // Search all j nodes. | |||
| GetJPrim(optimizer->resource()->manager()); | |||
| // Get j nodes that don't have embed j nodes. | |||
| std::vector<CNodePtr> todo; | |||
| // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph. | |||
| // ExpandJ innermost graph or primitive first. | |||
| std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo), | |||
| [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); }); | |||
| // Expand j nodes that don't have embed j nodes. | |||
| bool change = false; | |||
| for (auto &j_node : todo) { | |||
| auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource()); | |||
| optimizer->resource()->manager()->Replace(j_node, expanded_j); | |||
| change = true; | |||
| } | |||
| return change; | |||
| } | |||
| void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) { | |||
| j_nodes_.clear(); | |||
| for (auto &fg : manager->func_graphs()) { | |||
| std::vector<AnfNodePtr> &&toposet = TopoSort(fg->get_return()); | |||
| for (const auto &node : toposet) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| j_nodes_.push_back(node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -31,28 +31,17 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| namespace internal { | |||
| AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource); | |||
| } // namespace internal | |||
| // {prim::kPrimJ, C} | |||
| class ExpandJPrim : public AnfVisitor { | |||
| class ExpandJPrim { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| x_ = nullptr; | |||
| AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node); | |||
| if (x_ != nullptr) { | |||
| TraceGuard guard(std::make_shared<TraceExpandJ>(node->debug_info())); | |||
| auto j_node = internal::ExpandJ(x_, optimizer->resource()); | |||
| return j_node; | |||
| } | |||
| return nullptr; | |||
| } | |||
| void Visit(const ValueNodePtr &node) override { x_ = node; } | |||
| ExpandJPrim() = default; | |||
| virtual ~ExpandJPrim() = default; | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer); | |||
| void GetJPrim(const FuncGraphManagerPtr &manager); | |||
| private: | |||
| ValueNodePtr x_{nullptr}; | |||
| std::vector<CNodePtr> j_nodes_; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -41,6 +41,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/jit/pipeline_split.h" | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| @@ -166,7 +167,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.mini_step_allgather_replace_, | |||
| }); | |||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | |||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | |||
| opt::irpass::ResolveIRPassLib resolve_irpass; | |||
| opt::OptPassConfig resolve_pass = | |||
| @@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| {"parallel", opt::OptPassConfig(parallel::StepParallel)}, | |||
| {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | |||
| {"virtual_dataset", virtual_dataset}, | |||
| {"grad", grad}, | |||
| {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, | |||
| {"resolve", resolve_pass}, | |||
| {"a_after_grad", a_after_grad}, | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| @@ -88,14 +88,6 @@ class TestOptLib : public UT::Common { | |||
| irpass::OptimizeIRPassLib irpass; | |||
| }; | |||
| TEST_F(TestOptLib, test_expendJ) { | |||
| FuncGraphPtr before = getPyFun("test_expendJ"); | |||
| ASSERT_TRUE(nullptr != before); | |||
| FuncGraphPtr after = RunSubs(before, std::vector<SubstitutionPtr>({irpass.expand_jprim_})); | |||
| } | |||
| TEST_F(TestOptLib, test_simplify_always_true_false) { | |||
| FuncGraphPtr before1 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_1"); | |||
| FuncGraphPtr before2 = getPyFun.CallAndParseRet("test_simplify_always_true_false", "before_2"); | |||
| @@ -24,6 +24,7 @@ | |||
| #include "frontend/optimizer/cse_pass.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #include "debug/draw.h" | |||
| namespace mindspore { | |||
| @@ -38,23 +39,24 @@ class TestOptOptimizer : public UT::Common { | |||
| }; | |||
| TEST_F(TestOptOptimizer, test_step_opt) { | |||
| FuncGraphPtr before = getPyFun("test_expendJ"); | |||
| FuncGraphPtr before = getPyFun("test_expandJ"); | |||
| ASSERT_TRUE(nullptr != before); | |||
| pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>(); | |||
| std::shared_ptr<Optimizer> optimizer = Optimizer::MakeOptimizer("ut_test", res, | |||
| {{"main", | |||
| { | |||
| // Branch culling | |||
| irpass.switch_simplify_, | |||
| std::shared_ptr<Optimizer> optimizer = | |||
| Optimizer::MakeOptimizer("ut_test", res, | |||
| {{"main", | |||
| { | |||
| // Branch culling | |||
| irpass.switch_simplify_, | |||
| // Safe inlining | |||
| irpass.arithmetic_simplify_, | |||
| irpass.inline_, | |||
| }}, | |||
| {"grad", {irpass.expand_jprim_}}, | |||
| {"cse", OptPassConfig(CSEPass(false))}}, | |||
| true); | |||
| // Safe inlining | |||
| irpass.arithmetic_simplify_, | |||
| irpass.inline_, | |||
| }}, | |||
| {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, | |||
| {"cse", OptPassConfig(CSEPass(false))}}, | |||
| true); | |||
| EXPECT_TRUE(optimizer.get() != nullptr); | |||
| auto after = optimizer->step(before); | |||
| @@ -133,8 +133,8 @@ def cost(x): | |||
| J = Primitive('J') | |||
| def test_expendJ(x): | |||
| """ test_expendJ """ | |||
| def test_expandJ(x): | |||
| """ test_expandJ """ | |||
| return J(cost)(x) | |||