diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 8e945b47bd..567691af98 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -232,22 +232,15 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { for (size_t i = 0; i < cnode_morph->size(); i++) { auto node = cnode_morph->input(i); AdjointPtr node_adjoint = nullptr; - AnfNodePtr k = nullptr; - if (IsValueNode(node)) { - k = MapToK(cnode_morph, i); - node_adjoint = std::make_shared(node, k, tape_); - anfnode_to_adjoin_[node] = node_adjoint; + auto node_adjoint_iter = anfnode_to_adjoin_.find(node); + if (node_adjoint_iter != anfnode_to_adjoin_.end()) { + node_adjoint = node_adjoint_iter->second; } else { - auto node_adjoint_iter = anfnode_to_adjoin_.find(node); - if (node_adjoint_iter != anfnode_to_adjoin_.end()) { - node_adjoint = node_adjoint_iter->second; - } else { - // Input might be a CNode that needs to be handled previously. - node_adjoint = MapMorphism(node); - } - MS_EXCEPTION_IF_NULL(node_adjoint); - k = node_adjoint->k(); + // Input might be a CNode that needs to be handled previously. + node_adjoint = MapMorphism(node); } + MS_EXCEPTION_IF_NULL(node_adjoint); + AnfNodePtr k = node_adjoint->k(); if (k == nullptr) { MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; } @@ -537,93 +530,69 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { return nullptr; } -// Map func graph to K -AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { - auto f = func_graph_to_functor_.find(primal); +// Construct representation graph for {CNode, Index} of Primitive. +AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) { + auto primal = primitive_user->input(index); + if (!IsValueNode(primal)) { + MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive."; + } + ScopeGuard scope_guard(primal->scope()); + // Map Primitive to K + auto value_node = primal->cast(); + auto prim = GetValueNode(value_node); + if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { + MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; + need_cut_ = true; + } + auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_); + if (k_prim != nullptr) { + return NewValueNode(k_prim); + } + // When failed to find k_prim, try k_meta. + auto k_meta = g_k_prims.KMetaFuncGraph(prim); + if (k_meta != nullptr) { + return NewValueNode(k_meta); + } + MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K."; +} + +// Construct representation graph for ValueNode of FuncGraph. +AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) { + if (!IsValueNode(primal)) { + MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph."; + } + ScopeGuard scope_guard(primal->scope()); + // Map func graph to K + auto func_graph = GetValueNode(primal); + auto f = func_graph_to_functor_.find(func_graph); if (f != func_graph_to_functor_.end()) { - MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << "."; + MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << "."; return NewValueNode(f->second->k_graph_); } - - auto k_user_defined = KUserDefined(primal); + auto k_user_defined = KUserDefined(func_graph); if (k_user_defined != nullptr) { - MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << "."; + MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << "."; return NewValueNode(k_user_defined); } - - auto functor = std::make_shared(primal, resources_); + auto functor = std::make_shared(func_graph, resources_); functor->Init(); functor->MapObject(); functor->MapMorphism(); - MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << "."; + MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\""; return NewValueNode(functor->k_graph_); } -// Construct representation graph for primitive CNode. -AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) { - auto primal = primal_user->input(index); - ScopeGuard scope_guard(primal->scope()); - // Map primitive to K - if (IsValueNode(primal)) { - auto value_node = primal->cast(); - auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { - MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; - need_cut_ = true; - } - auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_); - if (k_prim != nullptr) { - return NewValueNode(k_prim); - } - // When failed to find k_prim, try k_meta. - auto k_meta = g_k_prims.KMetaFuncGraph(prim); - if (k_meta != nullptr) { - return NewValueNode(k_meta); - } +// Construct for ValueNode of Parameter. +AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) { + if (!primal->isa()) { + MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter."; } - return MapToK(primal); -} - -// Construct representation graph for given node. -AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { ScopeGuard scope_guard(primal->scope()); - // Map primitive to K - if (IsValueNode(primal)) { - auto value_node = primal->cast(); - auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { - MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; - need_cut_ = true; - } - auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_); - if (k_prim != nullptr) { - return NewValueNode(k_prim); - } - // When failed to find k_prim, try k_meta. - auto k_meta = g_k_prims.KMetaFuncGraph(prim); - if (k_meta != nullptr) { - return NewValueNode(k_meta); - } - } - - // Map func graph to K - if (IsValueNode(primal)) { - auto func_graph = GetValueNode(primal); - auto k_func = MapToK(func_graph); - return k_func; - } - - if (primal->isa()) { - TraceGuard trace_guard(std::make_shared(primal->debug_info())); - auto ret = k_graph_->add_parameter(); - return ret; - } - - if (!primal->isa()) { - MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode."; - } - return primal; + // Map Parameter to K + TraceGuard trace_guard(std::make_shared(primal->debug_info())); + auto ret = k_graph_->add_parameter(); + return ret; } bool DFunctor::IsInScope(const AnfNodePtr &node) { @@ -664,7 +633,7 @@ void DFunctor::MapParamObject() { for (auto &p : primal_graph_->parameters()) { ScopeGuard scope_guard(p->scope()); MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << "."; - auto adjoint = std::make_shared(p, MapToK(p), tape_); + auto adjoint = std::make_shared(p, MapParameterToK(p), tape_); UpdateAdjoint(adjoint); anfnode_to_adjoin_[p] = adjoint; } @@ -682,12 +651,32 @@ void DFunctor::MapValueObject() { anfnode_to_adjoin_[node] = adjoint; continue; } - // Skip Primitive. - if (IsValueNode(node) && GetValueNode(node) == prim::kPrimReturn) { - continue; + + AdjointPtr adjoint = nullptr; + if (IsValueNode(node)) { // Primitive. + if (GetValueNode(node) == prim::kPrimReturn) { + continue; + } + MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << "."; + auto &users = manager->node_users()[node]; + if (users.size() == 0) { + MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user."; + continue; + } else if (users.size() > 1) { + MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size(); + } + auto cnode = users.begin()->first->cast(); // We just use the first user. + auto index = users.begin()->second; + adjoint = std::make_shared(node, MapPrimitiveToK(cnode, index), tape_); + } else if (IsValueNode(node)) { // FuncGraph + MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << "."; + adjoint = std::make_shared(node, MapFuncGraphToK(node), tape_); + } else if (node->isa()) { // Parameter, hardly reach here. + MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << "."; + adjoint = std::make_shared(node, MapParameterToK(node), tape_); + } else { + adjoint = std::make_shared(node, node, tape_); } - MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << "."; - auto adjoint = std::make_shared(node, MapToK(node), tape_); UpdateAdjoint(adjoint); anfnode_to_adjoin_[node] = adjoint; } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 83ca4fa41d..5dceae57f8 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -81,12 +81,12 @@ class DFunctor : public std::enable_shared_from_this { void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); - // Map AnfNode object from D category to K category. - AnfNodePtr MapToK(const AnfNodePtr &primal); - // Map CNode object from D category to K category. - AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index); - // Map FuncGraph object from D category to K category. - AnfNodePtr MapToK(const FuncGraphPtr &primal); + // Map CNode/Index of Primitive to K. + AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index); + // Map ValueNode of FuncGraph to K. + AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal); + // Map ValueNode of Parameter to K. + AnfNodePtr MapParameterToK(const AnfNodePtr &primal); // MapObject impls. void MapFvObject(); void MapValueObject();