|
|
|
@@ -53,6 +53,37 @@ using Constant = ge::op::Constant; |
|
|
|
using Assign = ge::op::Assign; |
|
|
|
using Data = ge::op::Data; |
|
|
|
|
|
|
|
namespace { |
|
|
|
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) { |
|
|
|
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1); |
|
|
|
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> { |
|
|
|
std::vector<AnfNodePtr> vecs; |
|
|
|
if (node == nullptr) { |
|
|
|
return vecs; |
|
|
|
} |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
// Check if free variables used. |
|
|
|
for (const auto &input : inputs) { |
|
|
|
auto input_fg = GetValueNode<FuncGraphPtr>(input); |
|
|
|
if (input_fg) { |
|
|
|
for (auto &fv : input_fg->free_variables_nodes()) { |
|
|
|
if (fv->func_graph() == fg && fg->nodes().contains(fv)) { |
|
|
|
vecs.push_back(fv); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); |
|
|
|
} |
|
|
|
return vecs; |
|
|
|
}; |
|
|
|
|
|
|
|
return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
// ---------------implement of DfGraphConvertor------------- |
|
|
|
PrimType GetCNodeFuncType(const CNodePtr cnode) { |
|
|
|
if (cnode->inputs().empty()) { |
|
|
|
@@ -214,7 +245,7 @@ void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfN |
|
|
|
|
|
|
|
void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input) { |
|
|
|
DfGraphPtr init_graph = std::make_shared<DfGraph>("init"); |
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return()); |
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_); |
|
|
|
|
|
|
|
for (auto &it : nodes) { |
|
|
|
if (it->isa<ValueNode>()) { |
|
|
|
@@ -549,7 +580,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { |
|
|
|
|
|
|
|
// Convert all anf node to Operator |
|
|
|
MS_LOG(DEBUG) << "convert all node"; |
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return()); |
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_); |
|
|
|
for (auto &it : nodes) { |
|
|
|
(void)Convert(it); |
|
|
|
if (this->error_ != 0) { |
|
|
|
@@ -811,7 +842,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
} |
|
|
|
|
|
|
|
// Case node set input. |
|
|
|
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return()); |
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_); |
|
|
|
for (auto &it : nodes) { |
|
|
|
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) { |
|
|
|
auto node = it->cast<CNodePtr>(); |
|
|
|
@@ -825,7 +856,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
|
|
|
|
// set up dependencies |
|
|
|
MS_LOG(DEBUG) << "set up dependencies"; |
|
|
|
nodes = ::mindspore::TopoSort(anf_graph_->get_return()); |
|
|
|
nodes = GetOrderedCNodes(anf_graph_); |
|
|
|
for (auto &it : nodes) { |
|
|
|
SetNodeInput(it); |
|
|
|
SetOpControlInput(it); |
|
|
|
@@ -1195,6 +1226,51 @@ void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr |
|
|
|
} |
|
|
|
MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString(); |
|
|
|
} |
|
|
|
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) { |
|
|
|
if (input == nullptr || node == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
AnfNodePtr pred = input; |
|
|
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) { |
|
|
|
pred = pred->cast<CNodePtr>()->input(1); |
|
|
|
} |
|
|
|
|
|
|
|
// skip input of UMonad, IOMonad |
|
|
|
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// skip input of the None, UpdateState |
|
|
|
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) { |
|
|
|
pred = ParseLoadInput(pred->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
|
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op. |
|
|
|
std::string c_name = GetCNodeTargetFuncName(node); |
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); |
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { |
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name(); |
|
|
|
auto op_itor = op_cache_.find(pred.get()); |
|
|
|
if (op_itor == op_cache_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; |
|
|
|
} |
|
|
|
if (op_itor->second != nullptr && |
|
|
|
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && |
|
|
|
vars_.find(name) != vars_.end()) { |
|
|
|
auto variable = std::make_shared<Variable>(name); |
|
|
|
auto desc = vars_[name]->GetOutputDesc("y"); |
|
|
|
(void)variable->update_output_desc_y(desc); |
|
|
|
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; |
|
|
|
op_itor->second = variable; // replace parameter with variable |
|
|
|
vars_[name] = variable; |
|
|
|
} |
|
|
|
} |
|
|
|
return pred; |
|
|
|
} |
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { |
|
|
|
OperatorPtr src = Convert(node); |
|
|
|
@@ -1213,45 +1289,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node |
|
|
|
} else { |
|
|
|
pred = inputs[i]; |
|
|
|
} |
|
|
|
|
|
|
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) { |
|
|
|
pred = pred->cast<CNodePtr>()->input(1); |
|
|
|
} |
|
|
|
|
|
|
|
// skip input of UMonad, IOMonad |
|
|
|
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) { |
|
|
|
pred = GetRealInputNode(node, pred); |
|
|
|
if (pred == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// skip input of the None, Load, UpdateState |
|
|
|
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) { |
|
|
|
pred = ParseLoadInput(pred->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
|
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op. |
|
|
|
std::string c_name = GetCNodeTargetFuncName(node); |
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); |
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { |
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name(); |
|
|
|
auto op_itor = op_cache_.find(pred.get()); |
|
|
|
if (op_itor == op_cache_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; |
|
|
|
} |
|
|
|
if (op_itor->second != nullptr && |
|
|
|
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && |
|
|
|
vars_.find(name) != vars_.end()) { |
|
|
|
auto variable = std::make_shared<Variable>(name); |
|
|
|
auto desc = vars_[name]->GetOutputDesc("y"); |
|
|
|
(void)variable->update_output_desc_y(desc); |
|
|
|
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; |
|
|
|
op_itor->second = variable; // replace parameter with variable |
|
|
|
vars_[name] = variable; |
|
|
|
} |
|
|
|
} |
|
|
|
int index = SizeToInt(i); |
|
|
|
// find in out_hadnle_cache_ first |
|
|
|
auto it = out_handle_cache_.find(pred.get()); |
|
|
|
|