|
|
|
@@ -71,6 +71,11 @@ void DFunctor::Init(bool is_top) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void DFunctor::Finish() { |
|
|
|
CallDoutHoleOnTape(); |
|
|
|
EliminatePrimalGraph(); |
|
|
|
} |
|
|
|
|
|
|
|
void DFunctor::Clear() { |
|
|
|
func_graph_to_functor_.clear(); |
|
|
|
anfnode_to_adjoin_definition_.clear(); |
|
|
|
@@ -728,10 +733,7 @@ void DFunctor::CallDoutHoleOnTape() { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
FuncGraphPtr DFunctor::k_graph() { |
|
|
|
CallDoutHoleOnTape(); |
|
|
|
return k_graph_; |
|
|
|
} |
|
|
|
FuncGraphPtr DFunctor::k_graph() { return k_graph_; } |
|
|
|
|
|
|
|
void DFunctor::BroadCastStopFlag() { |
|
|
|
// As stop set expanding, all directly or indirectly stopped CNode will be cut off |
|
|
|
@@ -768,5 +770,28 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// To replace the primal graph with k graph |
|
|
|
void DFunctor::EliminatePrimalGraph() { |
|
|
|
auto k_vnode = NewValueNode(k_graph_); |
|
|
|
auto idx0 = NewValueNode(SizeToInt(0)); |
|
|
|
auto imm0 = std::make_shared<Int32Imm>(0); |
|
|
|
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0)); |
|
|
|
auto manager = primal_graph_->manager(); |
|
|
|
auto users = primal_graph_->func_graph_cnodes_index(); |
|
|
|
for (auto &it : users) { |
|
|
|
auto cnode = it.first->first->cast<CNodePtr>(); |
|
|
|
auto index = it.first->second; |
|
|
|
auto vnode = cnode->inputs()[index]; |
|
|
|
if (index != 0) { |
|
|
|
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
cnode->set_input(0, k_vnode); // Replace primal graph with k graph |
|
|
|
auto construct_wrapper = cnode->func_graph(); |
|
|
|
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0}); |
|
|
|
manager->Replace(cnode, getitem0); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ad |
|
|
|
} // namespace mindspore |