| @@ -234,8 +234,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { | |||
| AdjointPtr node_adjoint = nullptr; | |||
| AnfNodePtr k = nullptr; | |||
| if (IsValueNode<Primitive>(node)) { | |||
| TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode_morph->debug_info())); | |||
| k = MapToK(node); | |||
| k = MapToK(cnode_morph, i); | |||
| node_adjoint = std::make_shared<Adjoint>(node, k, tape_); | |||
| anfnode_to_adjoin_[node] = node_adjoint; | |||
| } else { | |||
| @@ -597,6 +596,31 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { | |||
| return NewValueNode(functor->k_graph_); | |||
| } | |||
| // Construct representation graph for primitive CNode. | |||
| AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) { | |||
| auto primal = primal_user->input(index); | |||
| ScopeGuard scope_guard(primal->scope()); | |||
| // Map primitive to K | |||
| if (IsValueNode<Primitive>(primal)) { | |||
| auto value_node = primal->cast<ValueNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(value_node); | |||
| if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { | |||
| MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; | |||
| need_cut_ = true; | |||
| } | |||
| auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_); | |||
| if (k_prim != nullptr) { | |||
| return NewValueNode(k_prim); | |||
| } | |||
| // When failed to find k_prim, try k_meta. | |||
| auto k_meta = g_k_prims.KMetaFuncGraph(prim); | |||
| if (k_meta != nullptr) { | |||
| return NewValueNode(k_meta); | |||
| } | |||
| } | |||
| return MapToK(primal); | |||
| } | |||
| // Construct representation graph for given node. | |||
| AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { | |||
| ScopeGuard scope_guard(primal->scope()); | |||
| @@ -608,7 +632,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { | |||
| MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; | |||
| need_cut_ = true; | |||
| } | |||
| auto k_prim = g_k_prims.KPrimitive(value_node, resources_); | |||
| auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_); | |||
| if (k_prim != nullptr) { | |||
| return NewValueNode(k_prim); | |||
| } | |||
| @@ -81,8 +81,10 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> { | |||
| void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); | |||
| AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); | |||
| AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); | |||
| // Map Anfnode object from D category to K category. | |||
| // Map AnfNode object from D category to K category. | |||
| AnfNodePtr MapToK(const AnfNodePtr &primal); | |||
| // Map CNode object from D category to K category. | |||
| AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index); | |||
| // Map FuncGraph object from D category to K category. | |||
| AnfNodePtr MapToK(const FuncGraphPtr &primal); | |||
| // MapObject impls. | |||
| @@ -129,7 +131,8 @@ class KPrim { | |||
| KPrim() = default; | |||
| ~KPrim() = default; | |||
| FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); | |||
| FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node, | |||
| const pipeline::ResourceBasePtr &resources); | |||
| MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); | |||
| FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); | |||
| @@ -145,7 +148,7 @@ class KPrim { | |||
| FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); | |||
| // Given a bprop rule, do the K mapping. | |||
| template <typename T> | |||
| FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); | |||
| FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode); | |||
| AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); | |||
| void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, | |||
| std::vector<AnfNodePtr> *const transf_args); | |||
| @@ -156,7 +159,7 @@ class KPrim { | |||
| }; | |||
| template <typename T> | |||
| FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||
| FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(primal); | |||
| MS_EXCEPTION_IF_NULL(bprop_fg); | |||
| CheckBprop(bprop_fg, primal->ToString()); | |||
| @@ -197,8 +200,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||
| TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); | |||
| (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); | |||
| auto out_value = outer->NewCNode(transf_args); | |||
| CNodePtr out_value = nullptr; | |||
| if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out. | |||
| TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info())); | |||
| out_value = outer->NewCNode(transf_args); | |||
| } else { | |||
| out_value = outer->NewCNode(transf_args); | |||
| } | |||
| (void)mng->Replace(out_param, out_value); | |||
| TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info())); | |||
| @@ -207,7 +215,6 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||
| // We remove all parameters except new_dout. | |||
| std::vector<AnfNodePtr> newBpropParams = {new_dout}; | |||
| cloned_bprop_fg->set_parameters(newBpropParams); | |||
| outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); | |||
| return BasicClone(outer); | |||
| } | |||
| @@ -64,7 +64,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||
| } | |||
| FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { | |||
| auto fg = g_k_prims.KPrimitive(value_node, resources); | |||
| auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources); | |||
| if (fg == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -102,7 +102,8 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { | |||
| MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; | |||
| } | |||
| FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { | |||
| FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node, | |||
| const pipeline::ResourceBasePtr &resources) { | |||
| if (!IsValueNode<Primitive>(value_node)) { | |||
| MS_LOG(EXCEPTION) << "Primitive node is not valid."; | |||
| } | |||
| @@ -141,7 +142,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R | |||
| } | |||
| } | |||
| auto expanded_fg = BpropToK(prim, bprop_fg); | |||
| auto expanded_fg = BpropToK(prim, bprop_fg, cnode); | |||
| if (expanded_fg == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed convert " << prim->name() | |||
| << " prim bprop function to J expanded func graph. NodeInfo: " | |||
| @@ -220,7 +221,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check | |||
| FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { | |||
| MS_EXCEPTION_IF_NULL(bprop_fg); | |||
| auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); | |||
| auto expanded_fg = BpropToK(fprop_fg, bprop_fg); | |||
| auto expanded_fg = BpropToK(fprop_fg, bprop_fg, nullptr); | |||
| if (expanded_fg == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() | |||
| << " Cell bprop function to K expanded func graph. NodeInfo: " | |||
| @@ -33,11 +33,11 @@ class TestGradImplementations : public UT::Common { | |||
| }; | |||
| TEST_F(TestGradImplementations, TestGetAugmentedGraph) { | |||
| FuncGraphPtr fg = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); | |||
| FuncGraphPtr fg = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr); | |||
| ASSERT_TRUE(fg != nullptr); | |||
| draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg); | |||
| auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); | |||
| auto fg1 = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr); | |||
| FuncGraphPairMapEquiv equiv_graph; | |||
| NodeMapEquiv equiv_node; | |||