| @@ -234,8 +234,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { | |||||
| AdjointPtr node_adjoint = nullptr; | AdjointPtr node_adjoint = nullptr; | ||||
| AnfNodePtr k = nullptr; | AnfNodePtr k = nullptr; | ||||
| if (IsValueNode<Primitive>(node)) { | if (IsValueNode<Primitive>(node)) { | ||||
| TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode_morph->debug_info())); | |||||
| k = MapToK(node); | |||||
| k = MapToK(cnode_morph, i); | |||||
| node_adjoint = std::make_shared<Adjoint>(node, k, tape_); | node_adjoint = std::make_shared<Adjoint>(node, k, tape_); | ||||
| anfnode_to_adjoin_[node] = node_adjoint; | anfnode_to_adjoin_[node] = node_adjoint; | ||||
| } else { | } else { | ||||
| @@ -597,6 +596,31 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { | |||||
| return NewValueNode(functor->k_graph_); | return NewValueNode(functor->k_graph_); | ||||
| } | } | ||||
| // Construct representation graph for primitive CNode. | |||||
| AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) { | |||||
| auto primal = primal_user->input(index); | |||||
| ScopeGuard scope_guard(primal->scope()); | |||||
| // Map primitive to K | |||||
| if (IsValueNode<Primitive>(primal)) { | |||||
| auto value_node = primal->cast<ValueNodePtr>(); | |||||
| auto prim = GetValueNode<PrimitivePtr>(value_node); | |||||
| if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { | |||||
| MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; | |||||
| need_cut_ = true; | |||||
| } | |||||
| auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_); | |||||
| if (k_prim != nullptr) { | |||||
| return NewValueNode(k_prim); | |||||
| } | |||||
| // When failed to find k_prim, try k_meta. | |||||
| auto k_meta = g_k_prims.KMetaFuncGraph(prim); | |||||
| if (k_meta != nullptr) { | |||||
| return NewValueNode(k_meta); | |||||
| } | |||||
| } | |||||
| return MapToK(primal); | |||||
| } | |||||
| // Construct representation graph for given node. | // Construct representation graph for given node. | ||||
| AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { | AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { | ||||
| ScopeGuard scope_guard(primal->scope()); | ScopeGuard scope_guard(primal->scope()); | ||||
| @@ -608,7 +632,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { | |||||
| MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; | MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; | ||||
| need_cut_ = true; | need_cut_ = true; | ||||
| } | } | ||||
| auto k_prim = g_k_prims.KPrimitive(value_node, resources_); | |||||
| auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_); | |||||
| if (k_prim != nullptr) { | if (k_prim != nullptr) { | ||||
| return NewValueNode(k_prim); | return NewValueNode(k_prim); | ||||
| } | } | ||||
| @@ -81,8 +81,10 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> { | |||||
| void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); | void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); | ||||
| AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); | AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); | ||||
| AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); | AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); | ||||
| // Map Anfnode object from D category to K category. | |||||
| // Map AnfNode object from D category to K category. | |||||
| AnfNodePtr MapToK(const AnfNodePtr &primal); | AnfNodePtr MapToK(const AnfNodePtr &primal); | ||||
| // Map CNode object from D category to K category. | |||||
| AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index); | |||||
| // Map FuncGraph object from D category to K category. | // Map FuncGraph object from D category to K category. | ||||
| AnfNodePtr MapToK(const FuncGraphPtr &primal); | AnfNodePtr MapToK(const FuncGraphPtr &primal); | ||||
| // MapObject impls. | // MapObject impls. | ||||
| @@ -129,7 +131,8 @@ class KPrim { | |||||
| KPrim() = default; | KPrim() = default; | ||||
| ~KPrim() = default; | ~KPrim() = default; | ||||
| FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); | |||||
| FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node, | |||||
| const pipeline::ResourceBasePtr &resources); | |||||
| MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); | MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); | ||||
| FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); | FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); | ||||
| @@ -145,7 +148,7 @@ class KPrim { | |||||
| FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); | FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); | ||||
| // Given a bprop rule, do the K mapping. | // Given a bprop rule, do the K mapping. | ||||
| template <typename T> | template <typename T> | ||||
| FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); | |||||
| FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode); | |||||
| AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); | AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); | ||||
| void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, | void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, | ||||
| std::vector<AnfNodePtr> *const transf_args); | std::vector<AnfNodePtr> *const transf_args); | ||||
| @@ -156,7 +159,7 @@ class KPrim { | |||||
| }; | }; | ||||
| template <typename T> | template <typename T> | ||||
| FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||||
| FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(primal); | MS_EXCEPTION_IF_NULL(primal); | ||||
| MS_EXCEPTION_IF_NULL(bprop_fg); | MS_EXCEPTION_IF_NULL(bprop_fg); | ||||
| CheckBprop(bprop_fg, primal->ToString()); | CheckBprop(bprop_fg, primal->ToString()); | ||||
| @@ -197,8 +200,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||||
| TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); | TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); | ||||
| (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); | (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); | ||||
| auto out_value = outer->NewCNode(transf_args); | |||||
| CNodePtr out_value = nullptr; | |||||
| if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out. | |||||
| TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info())); | |||||
| out_value = outer->NewCNode(transf_args); | |||||
| } else { | |||||
| out_value = outer->NewCNode(transf_args); | |||||
| } | |||||
| (void)mng->Replace(out_param, out_value); | (void)mng->Replace(out_param, out_value); | ||||
| TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info())); | TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info())); | ||||
| @@ -207,7 +215,6 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||||
| // We remove all parameters except new_dout. | // We remove all parameters except new_dout. | ||||
| std::vector<AnfNodePtr> newBpropParams = {new_dout}; | std::vector<AnfNodePtr> newBpropParams = {new_dout}; | ||||
| cloned_bprop_fg->set_parameters(newBpropParams); | cloned_bprop_fg->set_parameters(newBpropParams); | ||||
| outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); | outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); | ||||
| return BasicClone(outer); | return BasicClone(outer); | ||||
| } | } | ||||
| @@ -64,7 +64,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||||
| } | } | ||||
| FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { | FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { | ||||
| auto fg = g_k_prims.KPrimitive(value_node, resources); | |||||
| auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources); | |||||
| if (fg == nullptr) { | if (fg == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -102,7 +102,8 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { | |||||
| MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; | MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; | ||||
| } | } | ||||
| FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { | |||||
| FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node, | |||||
| const pipeline::ResourceBasePtr &resources) { | |||||
| if (!IsValueNode<Primitive>(value_node)) { | if (!IsValueNode<Primitive>(value_node)) { | ||||
| MS_LOG(EXCEPTION) << "Primitive node is not valid."; | MS_LOG(EXCEPTION) << "Primitive node is not valid."; | ||||
| } | } | ||||
| @@ -141,7 +142,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R | |||||
| } | } | ||||
| } | } | ||||
| auto expanded_fg = BpropToK(prim, bprop_fg); | |||||
| auto expanded_fg = BpropToK(prim, bprop_fg, cnode); | |||||
| if (expanded_fg == nullptr) { | if (expanded_fg == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failed convert " << prim->name() | MS_LOG(EXCEPTION) << "Failed convert " << prim->name() | ||||
| << " prim bprop function to J expanded func graph. NodeInfo: " | << " prim bprop function to J expanded func graph. NodeInfo: " | ||||
| @@ -220,7 +221,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check | |||||
| FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { | FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { | ||||
| MS_EXCEPTION_IF_NULL(bprop_fg); | MS_EXCEPTION_IF_NULL(bprop_fg); | ||||
| auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); | auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); | ||||
| auto expanded_fg = BpropToK(fprop_fg, bprop_fg); | |||||
| auto expanded_fg = BpropToK(fprop_fg, bprop_fg, nullptr); | |||||
| if (expanded_fg == nullptr) { | if (expanded_fg == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() | MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() | ||||
| << " Cell bprop function to K expanded func graph. NodeInfo: " | << " Cell bprop function to K expanded func graph. NodeInfo: " | ||||
| @@ -33,11 +33,11 @@ class TestGradImplementations : public UT::Common { | |||||
| }; | }; | ||||
| TEST_F(TestGradImplementations, TestGetAugmentedGraph) { | TEST_F(TestGradImplementations, TestGetAugmentedGraph) { | ||||
| FuncGraphPtr fg = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); | |||||
| FuncGraphPtr fg = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr); | |||||
| ASSERT_TRUE(fg != nullptr); | ASSERT_TRUE(fg != nullptr); | ||||
| draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg); | draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg); | ||||
| auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); | |||||
| auto fg1 = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr); | |||||
| FuncGraphPairMapEquiv equiv_graph; | FuncGraphPairMapEquiv equiv_graph; | ||||
| NodeMapEquiv equiv_node; | NodeMapEquiv equiv_node; | ||||