|
|
|
@@ -232,22 +232,15 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { |
|
|
|
for (size_t i = 0; i < cnode_morph->size(); i++) { |
|
|
|
auto node = cnode_morph->input(i); |
|
|
|
AdjointPtr node_adjoint = nullptr; |
|
|
|
AnfNodePtr k = nullptr; |
|
|
|
if (IsValueNode<Primitive>(node)) { |
|
|
|
k = MapToK(cnode_morph, i); |
|
|
|
node_adjoint = std::make_shared<Adjoint>(node, k, tape_); |
|
|
|
anfnode_to_adjoin_[node] = node_adjoint; |
|
|
|
auto node_adjoint_iter = anfnode_to_adjoin_.find(node); |
|
|
|
if (node_adjoint_iter != anfnode_to_adjoin_.end()) { |
|
|
|
node_adjoint = node_adjoint_iter->second; |
|
|
|
} else { |
|
|
|
auto node_adjoint_iter = anfnode_to_adjoin_.find(node); |
|
|
|
if (node_adjoint_iter != anfnode_to_adjoin_.end()) { |
|
|
|
node_adjoint = node_adjoint_iter->second; |
|
|
|
} else { |
|
|
|
// Input might be a CNode that needs to be handled previously. |
|
|
|
node_adjoint = MapMorphism(node); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(node_adjoint); |
|
|
|
k = node_adjoint->k(); |
|
|
|
// Input might be a CNode that needs to be handled previously. |
|
|
|
node_adjoint = MapMorphism(node); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(node_adjoint); |
|
|
|
AnfNodePtr k = node_adjoint->k(); |
|
|
|
if (k == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; |
|
|
|
} |
|
|
|
@@ -537,93 +530,69 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Map func graph to K |
|
|
|
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { |
|
|
|
auto f = func_graph_to_functor_.find(primal); |
|
|
|
// Construct representation graph for {CNode, Index} of Primitive. |
|
|
|
AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) { |
|
|
|
auto primal = primitive_user->input(index); |
|
|
|
if (!IsValueNode<Primitive>(primal)) { |
|
|
|
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive."; |
|
|
|
} |
|
|
|
ScopeGuard scope_guard(primal->scope()); |
|
|
|
// Map Primitive to K |
|
|
|
auto value_node = primal->cast<ValueNodePtr>(); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node); |
|
|
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { |
|
|
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; |
|
|
|
need_cut_ = true; |
|
|
|
} |
|
|
|
auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_); |
|
|
|
if (k_prim != nullptr) { |
|
|
|
return NewValueNode(k_prim); |
|
|
|
} |
|
|
|
// When failed to find k_prim, try k_meta. |
|
|
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim); |
|
|
|
if (k_meta != nullptr) { |
|
|
|
return NewValueNode(k_meta); |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K."; |
|
|
|
} |
|
|
|
|
|
|
|
// Construct representation graph for ValueNode of FuncGraph. |
|
|
|
AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) { |
|
|
|
if (!IsValueNode<FuncGraph>(primal)) { |
|
|
|
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph."; |
|
|
|
} |
|
|
|
ScopeGuard scope_guard(primal->scope()); |
|
|
|
// Map func graph to K |
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(primal); |
|
|
|
auto f = func_graph_to_functor_.find(func_graph); |
|
|
|
if (f != func_graph_to_functor_.end()) { |
|
|
|
MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << "."; |
|
|
|
MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << "."; |
|
|
|
return NewValueNode(f->second->k_graph_); |
|
|
|
} |
|
|
|
|
|
|
|
auto k_user_defined = KUserDefined(primal); |
|
|
|
auto k_user_defined = KUserDefined(func_graph); |
|
|
|
if (k_user_defined != nullptr) { |
|
|
|
MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << "."; |
|
|
|
MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << "."; |
|
|
|
return NewValueNode(k_user_defined); |
|
|
|
} |
|
|
|
|
|
|
|
auto functor = std::make_shared<DFunctor>(primal, resources_); |
|
|
|
auto functor = std::make_shared<DFunctor>(func_graph, resources_); |
|
|
|
functor->Init(); |
|
|
|
functor->MapObject(); |
|
|
|
functor->MapMorphism(); |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << "."; |
|
|
|
MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\""; |
|
|
|
return NewValueNode(functor->k_graph_); |
|
|
|
} |
|
|
|
|
|
|
|
// Construct representation graph for primitive CNode. |
|
|
|
AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) { |
|
|
|
auto primal = primal_user->input(index); |
|
|
|
ScopeGuard scope_guard(primal->scope()); |
|
|
|
// Map primitive to K |
|
|
|
if (IsValueNode<Primitive>(primal)) { |
|
|
|
auto value_node = primal->cast<ValueNodePtr>(); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node); |
|
|
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { |
|
|
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; |
|
|
|
need_cut_ = true; |
|
|
|
} |
|
|
|
auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_); |
|
|
|
if (k_prim != nullptr) { |
|
|
|
return NewValueNode(k_prim); |
|
|
|
} |
|
|
|
// When failed to find k_prim, try k_meta. |
|
|
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim); |
|
|
|
if (k_meta != nullptr) { |
|
|
|
return NewValueNode(k_meta); |
|
|
|
} |
|
|
|
// Construct for ValueNode of Parameter. |
|
|
|
AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) { |
|
|
|
if (!primal->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter."; |
|
|
|
} |
|
|
|
return MapToK(primal); |
|
|
|
} |
|
|
|
|
|
|
|
// Construct representation graph for given node. |
|
|
|
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { |
|
|
|
ScopeGuard scope_guard(primal->scope()); |
|
|
|
// Map primitive to K |
|
|
|
if (IsValueNode<Primitive>(primal)) { |
|
|
|
auto value_node = primal->cast<ValueNodePtr>(); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node); |
|
|
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { |
|
|
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; |
|
|
|
need_cut_ = true; |
|
|
|
} |
|
|
|
auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_); |
|
|
|
if (k_prim != nullptr) { |
|
|
|
return NewValueNode(k_prim); |
|
|
|
} |
|
|
|
// When failed to find k_prim, try k_meta. |
|
|
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim); |
|
|
|
if (k_meta != nullptr) { |
|
|
|
return NewValueNode(k_meta); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Map func graph to K |
|
|
|
if (IsValueNode<FuncGraph>(primal)) { |
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(primal); |
|
|
|
auto k_func = MapToK(func_graph); |
|
|
|
return k_func; |
|
|
|
} |
|
|
|
|
|
|
|
if (primal->isa<Parameter>()) { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info())); |
|
|
|
auto ret = k_graph_->add_parameter(); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
if (!primal->isa<ValueNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode."; |
|
|
|
} |
|
|
|
return primal; |
|
|
|
// Map Parameter to K |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info())); |
|
|
|
auto ret = k_graph_->add_parameter(); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
bool DFunctor::IsInScope(const AnfNodePtr &node) { |
|
|
|
@@ -664,7 +633,7 @@ void DFunctor::MapParamObject() { |
|
|
|
for (auto &p : primal_graph_->parameters()) { |
|
|
|
ScopeGuard scope_guard(p->scope()); |
|
|
|
MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << "."; |
|
|
|
auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_); |
|
|
|
auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_); |
|
|
|
UpdateAdjoint(adjoint); |
|
|
|
anfnode_to_adjoin_[p] = adjoint; |
|
|
|
} |
|
|
|
@@ -682,12 +651,32 @@ void DFunctor::MapValueObject() { |
|
|
|
anfnode_to_adjoin_[node] = adjoint; |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Skip Primitive. |
|
|
|
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) { |
|
|
|
continue; |
|
|
|
|
|
|
|
AdjointPtr adjoint = nullptr; |
|
|
|
if (IsValueNode<Primitive>(node)) { // Primitive. |
|
|
|
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << "."; |
|
|
|
auto &users = manager->node_users()[node]; |
|
|
|
if (users.size() == 0) { |
|
|
|
MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user."; |
|
|
|
continue; |
|
|
|
} else if (users.size() > 1) { |
|
|
|
MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size(); |
|
|
|
} |
|
|
|
auto cnode = users.begin()->first->cast<CNodePtr>(); // We just use the first user. |
|
|
|
auto index = users.begin()->second; |
|
|
|
adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_); |
|
|
|
} else if (IsValueNode<FuncGraph>(node)) { // FuncGraph |
|
|
|
MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << "."; |
|
|
|
adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_); |
|
|
|
} else if (node->isa<Parameter>()) { // Parameter, hardly reach here. |
|
|
|
MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << "."; |
|
|
|
adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_); |
|
|
|
} else { |
|
|
|
adjoint = std::make_shared<Adjoint>(node, node, tape_); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << "."; |
|
|
|
auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_); |
|
|
|
UpdateAdjoint(adjoint); |
|
|
|
anfnode_to_adjoin_[node] = adjoint; |
|
|
|
} |
|
|
|
|