| @@ -17,6 +17,7 @@ | |||
| #include "frontend/optimizer/ad/grad.h" | |||
| #include "frontend/optimizer/ad/dfunctor.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/dead_node_eliminate.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| @@ -24,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace ad { | |||
| namespace { | |||
| FuncGraphPtr PartialEliminateOptPass(const ResourcePtr &resource, const FuncGraphPtr &func_graph) { | |||
| FuncGraphPtr PartialEliminateOptPass(const pipeline::ResourcePtr &resource, const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(resource); | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| @@ -68,20 +69,32 @@ FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPt | |||
| } | |||
| } // namespace | |||
| FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { | |||
| FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimzer, bool is_top) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto gradkv = func_graph->transforms().find("grad"); | |||
| if (gradkv != func_graph->transforms().end()) { | |||
| return gradkv->second.func_graph(); | |||
| } | |||
| const auto &resources = optimzer->resource(); | |||
| auto manager_ptr = resources->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager_ptr); | |||
| manager_ptr->AddFuncGraph(func_graph); | |||
| FuncGraphPtr grad_fg = func_graph; | |||
| if (func_graph->func_graphs_used().size() != 0) { | |||
| grad_fg = LiftFv(resources, func_graph); | |||
| static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; | |||
| if (enable_closure) { | |||
| if (func_graph->func_graphs_used().size() != 0 && optimzer->is_first_order_j()) { | |||
| lift_fv_before_grad = true; | |||
| grad_fg = LiftFv(resources, func_graph); | |||
| } else { | |||
| lift_fv_before_grad = false; | |||
| opt::EliminateDeadNode(grad_fg); | |||
| } | |||
| } else { | |||
| if (func_graph->func_graphs_used().size() != 0) { | |||
| grad_fg = LiftFv(resources, func_graph); | |||
| } | |||
| } | |||
| auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) { | |||
| @@ -22,13 +22,11 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "pipeline/jit/resource.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace ad { | |||
| using ResourcePtr = std::shared_ptr<pipeline::Resource>; | |||
| FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true); | |||
| FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer, bool is_top = true); | |||
| FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); | |||
| MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); | |||
| void CleanRes(); | |||
| @@ -83,7 +83,8 @@ void CheckSwitchWithSideEffect(const FuncGraphPtr &fg) { | |||
| } | |||
| } | |||
| AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { | |||
| AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) { | |||
| AnfNodePtr expanded_node = nullptr; | |||
| if (IsValueNode<FuncGraph>(vnode)) { | |||
| ScopeGuard scope_guard(vnode->scope()); | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(vnode); | |||
| @@ -92,13 +93,15 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &r | |||
| CheckSwitchWithSideEffect(func_graph); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; | |||
| auto newfg = ad::Grad(func_graph, resource); | |||
| return NewValueNode(newfg); | |||
| auto newfg = ad::Grad(func_graph, optimizer); | |||
| expanded_node = NewValueNode(newfg); | |||
| } else if (IsValueNode<Primitive>(vnode)) { | |||
| expanded_node = ExpandJPrimitive(vnode, optimizer->resource()); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| if (IsValueNode<Primitive>(vnode)) { | |||
| return ExpandJPrimitive(vnode, resource); | |||
| } | |||
| return nullptr; | |||
| optimizer->set_is_first_order_j(false); | |||
| return expanded_node; | |||
| } | |||
| } // namespace internal | |||
| @@ -122,7 +125,7 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr | |||
| bool change = false; | |||
| auto manager = optimizer->manager(); | |||
| for (auto &j_node : todo) { | |||
| auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource()); | |||
| auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer); | |||
| manager->Replace(j_node, expanded_j); | |||
| change = true; | |||
| } | |||
| @@ -142,6 +142,12 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| return optimizer; | |||
| } | |||
| static std::shared_ptr<Optimizer> MakeEmptyOptimizer(const pipeline::ResourceBasePtr resource_ptr) { | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>("empty", resource_ptr, false); | |||
| optimizer->Init(OptPassGroupMap{}, false); | |||
| return optimizer; | |||
| } | |||
| FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | |||
| if (!is_enable_) { | |||
| return func_graph; | |||
| @@ -240,6 +246,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| bool traverse_nodes_first() { return traverse_nodes_first_; } | |||
| bool is_first_order_j() { return is_first_order_j_; } | |||
| void set_is_first_order_j(bool is_first_order_j) { is_first_order_j_ = is_first_order_j; } | |||
| struct { | |||
| int64_t counter; | |||
| std::string name; | |||
| @@ -257,6 +266,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| bool is_enable_; | |||
| bool is_untyped_generated_; | |||
| bool traverse_nodes_first_; | |||
| // A flag to indicate if it's the first order J or innermost J in GraphMode. | |||
| bool is_first_order_j_{true}; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -725,7 +725,7 @@ bool EliminateForwardCNode(const ResourcePtr &res) { | |||
| auto grad_exec = pynative_exec->grad_executor(); | |||
| bool eliminate_forward = grad_exec->eliminate_forward(); | |||
| grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty()); | |||
| auto grad_graph = ad::Grad(ms_func_graph, res); | |||
| auto grad_graph = ad::Grad(ms_func_graph, opt::Optimizer::MakeEmptyOptimizer(res)); | |||
| MS_EXCEPTION_IF_NULL(grad_graph); | |||
| graph_executor->SetGradGraph(grad_graph, phase); | |||
| ModifyOutputNode(ms_func_graph); | |||
| @@ -3167,7 +3167,7 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const py::tuple &forw | |||
| r->manager()->AddFuncGraph(first_grad_fg); | |||
| set_eliminate_forward(false); | |||
| first_grad_fg->transforms().erase(kGrad); | |||
| FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r); | |||
| FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, opt::Optimizer::MakeEmptyOptimizer(r)); | |||
| set_eliminate_forward(true); | |||
| DumpGraphIR("second_grad_fg.ir", second_grad_fg); | |||
| r->Clean(); | |||
| @@ -28,6 +28,7 @@ | |||
| #include "pipeline/jit/parse/parse.h" | |||
| #include "debug/draw.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace ad { | |||
| @@ -44,7 +45,7 @@ class TestAD : public UT::Common { | |||
| FuncGraphPtr g = getPyFun(testCase); | |||
| resourcePtr->manager()->RemoveRoots(); | |||
| resourcePtr->manager()->AddFuncGraph(g, true); | |||
| FuncGraphPtr dg = Grad(g, resourcePtr); | |||
| FuncGraphPtr dg = Grad(g, opt::Optimizer::MakeEmptyOptimizer(resourcePtr)); | |||
| AssertExpect(testCase, dg); | |||
| } | |||
| @@ -188,8 +189,8 @@ TEST_F(TestAD, test_prim_switch) { | |||
| TEST_F(TestAD, test_grad_cache) { | |||
| FuncGraphPtr g = getPyFun("test_null"); | |||
| FuncGraphPtr dg1 = Grad(g, resourcePtr); | |||
| FuncGraphPtr dg2 = Grad(g, resourcePtr); | |||
| FuncGraphPtr dg1 = Grad(g, opt::Optimizer::MakeEmptyOptimizer(resourcePtr)); | |||
| FuncGraphPtr dg2 = Grad(g, opt::Optimizer::MakeEmptyOptimizer(resourcePtr)); | |||
| ASSERT_TRUE(dg1 == dg2); | |||
| } | |||