|
|
|
@@ -593,9 +593,17 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { |
|
|
|
AnfNodePtr anf_out = node; |
|
|
|
AnfNodePtr pre_node = nullptr; |
|
|
|
|
|
|
|
// trace Parameter node |
|
|
|
// Trace value node |
|
|
|
if (node->isa<ValueNode>()) { |
|
|
|
auto op = Convert(anf_out); |
|
|
|
graph_outputs_.emplace_back(std::make_pair(*op, "")); |
|
|
|
AddGraphConstInput(op); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// Trace Parameter node |
|
|
|
TraceOutputFromParameter(anf_out); |
|
|
|
// then trace cnode |
|
|
|
// Then trace cnode |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -869,7 +877,12 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "trace output"; |
|
|
|
graph_outputs_.clear(); |
|
|
|
TraceOutput(anf_graph_->get_return()->input(1)); |
|
|
|
|
|
|
|
// Add const nodes as graph input for some operator work with constant |
|
|
|
MS_LOG(INFO) << "graph const input size: " << graph_const_inputs_.size(); |
|
|
|
std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs), |
|
|
|
[](OperatorPtr x) { return *x; }); |
|
|
|
|
|
|
|
@@ -879,8 +892,6 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { |
|
|
|
// set graph output |
|
|
|
// set the value of finale return apply node as the output of dataflow graph |
|
|
|
MS_LOG(DEBUG) << "set output"; |
|
|
|
graph_outputs_.clear(); |
|
|
|
TraceOutput(anf_graph_->get_return()->input(1)); |
|
|
|
MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size(); |
|
|
|
(void)df_graph_->SetOutputs(graph_outputs_); |
|
|
|
|
|
|
|
@@ -1036,7 +1047,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node |
|
|
|
} |
|
|
|
|
|
|
|
void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) { |
|
|
|
if (op->GetOpType() == "Constant") { |
|
|
|
if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") { |
|
|
|
graph_const_inputs_.push_back(op); |
|
|
|
} |
|
|
|
} |
|
|
|
|